diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index a62d41b37..62956ed37 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -676,7 +676,7 @@ def __init__(
                 "cpu" if device is None else device.type
             )  # TODO: replace hardcoded cpu with DEFAULT_DEVICE_TYPE in torchrec.distributed.types when torch package issue resolved
         else:
-            # If no device is provided, use "cuda".
+            # BUG: device will default to cuda if cpu specified
             self._device_type: str = (
                 device.type
                 if device is not None and device.type in {"meta", "cuda", "mtia"}
diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index 8d07a04b7..dd7029459 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -319,7 +319,7 @@ def bucketize_kjt_inference(
             num_buckets=num_buckets,
             block_sizes=block_sizes_new_type,
             bucketize_pos=bucketize_pos,
-            block_bucketize_pos=block_bucketize_row_pos,  # each tensor should have the same dtype as kjt.lengths()
+            block_bucketize_pos=block_bucketize_row_pos,
         )
     else:
         (
diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py
index 36c1c2586..df3d6098a 100644
--- a/torchrec/distributed/sharding/rw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -31,6 +31,7 @@
 )
 from torchrec.distributed.sharding.rw_sharding import (
     BaseRwEmbeddingSharding,
+    get_embedding_shard_metadata,
     InferRwSparseFeaturesDist,
     RwSparseFeaturesDist,
 )
@@ -39,7 +40,6 @@
     SequenceShardingContext,
 )
 from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs
-from torchrec.distributed.utils import none_throws
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


@@ -199,16 +199,9 @@ def create_input_dist(
         num_features = self._get_num_features()
         feature_hash_sizes = self._get_feature_hash_sizes()

-        emb_sharding = []
-        for embedding_table_group in self._grouped_embedding_configs_per_rank[0]:
-            for table in embedding_table_group.embedding_tables:
-                shard_split_offsets = [
-                    shard.shard_offsets[0]
-                    for shard in none_throws(table.global_metadata).shards_metadata
-                ]
-                shard_split_offsets.append(none_throws(table.global_metadata).size[0])
-                emb_sharding.extend([shard_split_offsets] * len(table.embedding_names))
-
+        (emb_sharding, is_even_sharding) = get_embedding_shard_metadata(
+            self._grouped_embedding_configs_per_rank
+        )
         return InferRwSparseFeaturesDist(
             world_size=self._world_size,
             num_features=num_features,
@@ -217,7 +210,7 @@ def create_input_dist(
             is_sequence=True,
             has_feature_processor=self._has_feature_processor,
             need_pos=False,
-            embedding_shard_metadata=emb_sharding,
+            embedding_shard_metadata=emb_sharding if not is_even_sharding else None,
         )

     def create_lookup(
diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py
index a59ce87d5..fc80ab17e 100755
--- a/torchrec/distributed/tests/test_infer_shardings.py
+++ b/torchrec/distributed/tests/test_infer_shardings.py
@@ -77,6 +77,12 @@
     PositionWeightedModuleCollection,
 )
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
+from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
+from torchrec.modules.mc_modules import (
+    DistanceLFU_EvictionPolicy,
+    ManagedCollisionCollection,
+    MCHManagedCollisionModule,
+)

 torch.fx.wrap("len")

@@ -135,6 +141,10 @@ def placement_helper(device_type: str, index: int = 0) -> str:


 class InferShardingsTest(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        set_propogate_device(True)
+
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs available",
@@ -146,7 +156,6 @@ class InferShardingsTest(unittest.TestCase):
     )
     @settings(max_examples=4, deadline=None)
     def test_tw(self, weight_dtype: torch.dtype, device_type: str) -> None:
-        set_propogate_device(True)
         num_embeddings = 256
         emb_dim = 16
         world_size = 2
@@ -217,7 +226,6 @@ def test_tw(self, weight_dtype: torch.dtype, device_type: str) -> None:
     def test_tw_ebc_full_rank_weighted_ebc_with_empty_rank(
         self, weight_dtype: torch.dtype, device_type: str
     ) -> None:
-        set_propogate_device(True)
         num_embeddings = 256
         emb_dim = 16
         world_size = 2
@@ -288,7 +296,6 @@ def test_tw_ebc_full_rank_weighted_ebc_with_empty_rank(
     )
     @settings(max_examples=4, deadline=None)
     def test_rw(self, weight_dtype: torch.dtype, device_type: str) -> None:
-        set_propogate_device(True)
         num_embeddings = 256
         emb_dim = 16
         world_size = 2
@@ -362,7 +369,6 @@ def test_rw(self, weight_dtype: torch.dtype, device_type: str) -> None:
     def test_cw(
         self, test_permute: bool, weight_dtype: torch.dtype, device_type: str
     ) -> None:
-        set_propogate_device(True)
         num_embeddings = 64
         emb_dim = 512
         emb_dim_4 = emb_dim // 4
@@ -479,7 +485,6 @@ def test_cw_with_smaller_emb_dim(
         weight_dtype: torch.dtype,
         device_type: str,
     ) -> None:
-        set_propogate_device(True)
         num_embeddings = 64
         emb_dim_4 = emb_dim // 4
         world_size = 2
@@ -566,7 +571,6 @@ def test_cw_with_smaller_emb_dim(
     def test_cw_multiple_tables_with_permute(
         self, weight_dtype: torch.dtype, device_type: str
     ) -> None:
-        set_propogate_device(True)
         num_embeddings = 64
         emb_dim = 512
         emb_dim_2 = 512 // 2
@@ -678,7 +682,6 @@ def test_cw_multiple_tables_with_permute(
     def test_cw_irregular_shard_placement(
         self, weight_dtype: torch.dtype, device_type: str
     ) -> None:
-        set_propogate_device(True)
         num_embeddings = 64
         emb_dim = 384
         emb_dim_2 = emb_dim // 2
@@ -802,7 +805,6 @@ def test_cw_irregular_shard_placement(
     def test_cw_sequence(
         self, device_type_weight_dtype: Tuple[str, torch.dtype]
     ) -> None:
-        set_propogate_device(True)
         device_type, weight_dtype = device_type_weight_dtype
         num_embeddings = 4
         emb_dim = 512
@@ -893,6 +895,7 @@ def test_cw_sequence(
         ]

         sharded_model.load_state_dict(non_sharded_model.state_dict())
+
         sharded_output = sharded_model(*inputs[0])
         non_sharded_output = non_sharded_model(*inputs[0])
         assert_close(sharded_output, non_sharded_output)
@@ -918,7 +921,6 @@ def test_cw_sequence(
     )
     @settings(max_examples=4, deadline=None)
     def test_tw_sequence(self, weight_dtype: torch.dtype, device_type: str) -> None:
-        set_propogate_device(True)
         num_embeddings = 10
         emb_dim = 16
         world_size = 2
@@ -1040,7 +1042,6 @@ def test_tw_sequence(self, weight_dtype: torch.dtype, device_type: str) -> None:
     def test_rw_sequence(
         self, device_type_weight_dtype: Tuple[str, torch.dtype]
     ) -> None:
-        set_propogate_device(True)
         device_type, weight_dtype = device_type_weight_dtype
         num_embeddings = 10
         emb_dim = 16
@@ -1151,7 +1152,6 @@ def test_rw_sequence(
     )
     @settings(max_examples=4, deadline=None)
     def test_rw_sequence_uneven(self, weight_dtype: torch.dtype, device: str) -> None:
-        set_propogate_device(True)
         num_embeddings = 512
         emb_dim = 64
         world_size = 4
@@ -1353,7 +1353,6 @@ def test_rw_sequence_uneven(self, weight_dtype: torch.dtype, device: str) -> None:
     def test_mix_tw_rw_sequence(
         self, weight_dtype: torch.dtype, device_type: str
     ) -> None:
-        set_propogate_device(True)
         num_embeddings = 10
         emb_dim = 16
         world_size = 2
@@ -1465,7 +1464,6 @@ def test_mix_tw_rw_sequence(
     def test_mix_tw_rw_sequence_missing_feature_on_rank(
         self, weight_dtype: torch.dtype, device_type: str
     ) -> None:
-        set_propogate_device(True)
         num_embeddings = 10
         emb_dim = 16
         world_size = 2
@@ -1585,7 +1583,6 @@ def test_rw_uneven_sharding(
         uneven_shard_pattern: Tuple[int, int, int, int],
         device: str,
     ) -> None:
-        set_propogate_device(True)
         num_embeddings, size0, size1, size2 = uneven_shard_pattern
         size2 = min(size2, num_embeddings - size0 - size1)
         emb_dim = 64
@@ -1680,7 +1677,6 @@ def test_rw_uneven_sharding_mutiple_table(
         weight_dtype: torch.dtype,
         device: str,
     ) -> None:
-        set_propogate_device(True)
         num_embeddings = 512
         emb_dim = 64
         local_size = 4
@@ -1837,7 +1833,6 @@ def test_mix_sharding_mutiple_table(
         weight_dtype: torch.dtype,
         device: str,
     ) -> None:
-        set_propogate_device(True)
         num_embeddings = 512
         emb_dim = 64
         local_size = 4
@@ -1919,8 +1914,7 @@ def test_sharded_quant_fp_ebc_tw(
         world_size = 2
         batch_size = 2
         local_device = torch.device(f"{device_type}:0")
-
-        topology: Topology = Topology(world_size=world_size, compute_device="cuda")
+        topology: Topology = Topology(world_size=world_size, compute_device=device_type)
         mi = TestModelInfo(
             dense_device=local_device,
             sparse_device=local_device,
@@ -2003,7 +1997,7 @@ def test_sharded_quant_fp_ebc_tw(
         print(f"quant_model:\n{quant_model}")
         non_sharded_output = mi.quant_model(*inputs[0])

-        topology: Topology = Topology(world_size=world_size, compute_device="cuda")
+        topology: Topology = Topology(world_size=world_size, compute_device=device_type)
         mi.planner = EmbeddingShardingPlanner(
             topology=topology,
             batch_size=batch_size,
@@ -2075,6 +2069,163 @@ def test_sharded_quant_fp_ebc_tw(
         gm_script_output = gm_script(*inputs[0])
         assert_close(sharded_output, gm_script_output)

+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
+    # pyre-ignore
+    @given(
+        weight_dtype=st.sampled_from([torch.qint8]),
+        device_type=st.sampled_from(["cpu", "cuda"]),
+    )
+    @settings(max_examples=2, deadline=None)
+    def test_sharded_quant_mc_ec_rw(
+        self, weight_dtype: torch.dtype, device_type: str
+    ) -> None:
+        num_embeddings = 10
+        emb_dim = 16
+        world_size = 2
+        batch_size = 2
+        local_device = torch.device(f"{device_type}:0")
+
+        topology: Topology = Topology(world_size=world_size, compute_device=device_type)
+        mi = TestModelInfo(
+            dense_device=local_device,
+            sparse_device=local_device,
+            num_features=1,
+            num_float_features=10,
+            num_weighted_features=0,
+            topology=topology,
+        )
+        mi.planner = EmbeddingShardingPlanner(
+            topology=topology,
+            batch_size=batch_size,
+            enumerator=EmbeddingEnumerator(
+                topology=topology,
+                batch_size=batch_size,
+                estimator=[
+                    EmbeddingPerfEstimator(topology=topology, is_inference=True),
+                    EmbeddingStorageEstimator(topology=topology),
+                ],
+            ),
+        )
+
+        mi.tables = [
+            EmbeddingConfig(
+                num_embeddings=num_embeddings,
+                embedding_dim=emb_dim,
+                name=f"table_{i}",
+                feature_names=[f"feature_{i}"],
+            )
+            for i in range(mi.num_features)
+        ]
+
+        mi.model = KJTInputWrapper(
+            module_kjt_input=torch.nn.Sequential(
+                ManagedCollisionEmbeddingCollection(
+                    EmbeddingCollection(
+                        tables=mi.tables,
+                        device=mi.sparse_device,
+                    ),
+                    ManagedCollisionCollection(
+                        managed_collision_modules={
+                            "table_0": MCHManagedCollisionModule(
+                                zch_size=num_embeddings,
+                                input_hash_size=4000,
+                                device=mi.sparse_device,
+                                eviction_interval=2,
+                                eviction_policy=DistanceLFU_EvictionPolicy(),
+                            )
+                        },
+                        # pyre-ignore [6] Incompatible parameter type
+                        embedding_configs=mi.tables,
+                    ),
+                )
+            )
+        )
+        model_inputs: List[ModelInput] = prep_inputs(
+            mi, world_size, batch_size, long_indices=True
+        )
+        inputs = []
+        for model_input in model_inputs:
+            kjt = model_input.idlist_features
+            kjt = kjt.to(local_device)
+            weights = None
+            inputs.append(
+                (
+                    kjt._keys,
+                    kjt._values,
+                    weights,
+                    kjt._lengths,
+                    kjt._offsets,
+                )
+            )
+
+        mi.model(*inputs[0])
+        print(f"model:\n{mi.model}")
+        assert mi.model.training is True
+        mi.quant_model = quantize(
+            module=mi.model,
+            inplace=False,
+            register_tbes=False,
+            quant_state_dict_split_scale_bias=True,
+            weight_dtype=weight_dtype,
+        )
+        quant_model = mi.quant_model
+        assert quant_model.training is False
+        print(f"quant_model:\n{quant_model}")
+        non_sharded_output, _ = mi.quant_model(*inputs[0])
+
+        topology: Topology = Topology(world_size=world_size, compute_device=device_type)
+        mi.planner = EmbeddingShardingPlanner(
+            topology=topology,
+            batch_size=batch_size,
+            enumerator=EmbeddingEnumerator(
+                topology=topology,
+                batch_size=batch_size,
+                estimator=[
+                    EmbeddingPerfEstimator(topology=topology, is_inference=True),
+                    EmbeddingStorageEstimator(topology=topology),
+                ],
+            ),
+        )
+        sharder = QuantEmbeddingCollectionSharder()
+        # pyre-ignore
+        plan = mi.planner.plan(
+            mi.quant_model,
+            [sharder],
+        )
+
+        sharded_model = shard_qec(
+            mi,
+            sharding_type=ShardingType.ROW_WISE,
+            device=local_device,
+            expected_shards=None,
+            plan=plan,
+        )
+
+        print(f"sharded_model:\n{sharded_model}")
+        for n, m in sharded_model.named_modules():
+            print(f"sharded_model.MODULE[{n}]:{type(m)}")
+
+        sharded_model.load_state_dict(quant_model.state_dict())
+        sharded_output, _ = sharded_model(*inputs[0])
+
+        assert_close(non_sharded_output, sharded_output)
+        gm: torch.fx.GraphModule = symbolic_trace(
+            sharded_model,
+            leaf_modules=[
+                "IntNBitTableBatchedEmbeddingBagsCodegen",
+                "ComputeJTDictToKJT",
+            ],
+        )
+
+        print(f"fx.graph:\n{gm.graph}")
+        gm_script = torch.jit.script(gm)
+        print(f"gm_script:\n{gm_script}")
+        gm_script_output, _ = gm_script(*inputs[0])
+        assert_close(sharded_output, gm_script_output)
+
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
         "Not enough GPUs available",
@@ -2160,7 +2311,9 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None:
         quant_model = mi.quant_model
         print(f"quant_model:\n{quant_model}")

-        topology: Topology = Topology(world_size=world_size, compute_device="cuda")
+        topology: Topology = Topology(
+            world_size=world_size, compute_device=compute_device
+        )
         mi.planner = EmbeddingShardingPlanner(
             topology=topology,
             batch_size=batch_size,