From 5e6ebb9c6488bdbb58c0d8a3a94ef1d4d15c852c Mon Sep 17 00:00:00 2001
From: pytorchbot
Date: Wed, 16 Oct 2024 11:34:49 +0000
Subject: [PATCH] 2024-10-16 nightly release (d317c0b048686914b8e6ddeb02320ac4fefba954)

---
 .../distributed/batched_embedding_kernel.py | 65 +++----------------
 .../distributed/sharding/grid_sharding.py   | 19 ++++--
 2 files changed, 25 insertions(+), 59 deletions(-)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 8500e21fa..0da8df7d8 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -211,42 +211,6 @@ class ShardParams:
             local_metadata: List[ShardMetadata]
             embedding_weights: List[torch.Tensor]
 
-        def get_optimizer_single_value_shard_metadata_and_global_metadata(
-            table_global_metadata: ShardedTensorMetadata,
-            optimizer_state: torch.Tensor,
-        ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
-            table_global_shards_metadata: List[ShardMetadata] = (
-                table_global_metadata.shards_metadata
-            )
-
-            table_shard_metadata_to_optimizer_shard_metadata = {}
-            for offset, table_shard_metadata in enumerate(table_global_shards_metadata):
-                table_shard_metadata_to_optimizer_shard_metadata[
-                    table_shard_metadata
-                ] = ShardMetadata(
-                    shard_sizes=[1],  # single value optimizer state
-                    shard_offsets=[offset],  # offset increases by 1 for each shard
-                    placement=table_shard_metadata.placement,
-                )
-
-            tensor_properties = TensorProperties(
-                dtype=optimizer_state.dtype,
-                layout=optimizer_state.layout,
-                requires_grad=False,
-            )
-            single_value_optimizer_st_metadata = ShardedTensorMetadata(
-                shards_metadata=list(
-                    table_shard_metadata_to_optimizer_shard_metadata.values()
-                ),
-                size=torch.Size([len(table_global_shards_metadata)]),
-                tensor_properties=tensor_properties,
-            )
-
-            return (
-                table_shard_metadata_to_optimizer_shard_metadata,
-                single_value_optimizer_st_metadata,
-            )
-
         def get_optimizer_rowwise_shard_metadata_and_global_metadata(
             table_global_metadata: ShardedTensorMetadata,
             optimizer_state: torch.Tensor,
@@ -392,10 +356,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
             if optimizer_states:
                 optimizer_state_values = tuple(optimizer_states.values())
                 for optimizer_state_value in optimizer_state_values:
-                    assert (
-                        table_config.local_rows == optimizer_state_value.size(0)
-                        or optimizer_state_value.nelement() == 1  # single value state
-                    )
+                    assert table_config.local_rows == optimizer_state_value.size(0)
                 optimizer_states_keys_by_table[table_config.name] = list(
                     optimizer_states.keys()
                 )
@@ -474,35 +435,29 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
                 momentum_local_shards: List[Shard] = []
                 optimizer_sharded_tensor_metadata: ShardedTensorMetadata
 
-                optim_state = shard_params.optimizer_states[0][momentum_idx - 1]  # pyre-ignore[16]
-                if optim_state.nelement() == 1:
-                    # single value state: one value per table
-                    (
-                        table_shard_metadata_to_optimizer_shard_metadata,
-                        optimizer_sharded_tensor_metadata,
-                    ) = get_optimizer_single_value_shard_metadata_and_global_metadata(
-                        table_config.global_metadata,
-                        optim_state,
-                    )
-                elif optim_state.dim() == 1:
-                    # rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1
+                is_rowwise_optimizer_state: bool = (
+                    # pyre-ignore
+                    shard_params.optimizer_states[0][momentum_idx - 1].dim()
+                    == 1
+                )
+
+                if is_rowwise_optimizer_state:
                     (
                         table_shard_metadata_to_optimizer_shard_metadata,
                         optimizer_sharded_tensor_metadata,
                     ) = get_optimizer_rowwise_shard_metadata_and_global_metadata(
                        table_config.global_metadata,
-                        optim_state,
+                        shard_params.optimizer_states[0][momentum_idx - 1],
                         sharding_dim,
                         is_grid_sharded,
                     )
                 else:
-                    # pointwise state: param.shape == state.shape
                     (
                         table_shard_metadata_to_optimizer_shard_metadata,
                         optimizer_sharded_tensor_metadata,
                     ) = get_optimizer_pointwise_shard_metadata_and_global_metadata(
                         table_config.global_metadata,
-                        optim_state,
+                        shard_params.optimizer_states[0][momentum_idx - 1],
                     )
 
                 for optimizer_state, table_shard_local_metadata in zip(
diff --git a/torchrec/distributed/sharding/grid_sharding.py b/torchrec/distributed/sharding/grid_sharding.py
index 2b57371d5..ef49cbb30 100644
--- a/torchrec/distributed/sharding/grid_sharding.py
+++ b/torchrec/distributed/sharding/grid_sharding.py
@@ -115,7 +115,8 @@ def __init__(
 
     def _init_combined_embeddings(self) -> None:
         """
-        similar to CW sharding, but this time each CW shard is on a node and not rank
+        Initializes combined embeddings, similar to the CW sharding implementation,
+        but in this case each CW shard is placed per node rather than per rank.
         """
         embedding_names = []
         for grouped_embedding_configs in self._grouped_embedding_configs_per_node:
@@ -179,6 +180,17 @@ def _shard(
         self,
         sharding_infos: List[EmbeddingShardingInfo],
     ) -> List[List[ShardedEmbeddingTable]]:
+        """
+        Shards the embedding tables.
+        This method takes the sharding infos and returns a list of lists of
+        sharded embedding tables, where each inner list represents the tables
+        for a specific rank.
+
+        Args:
+            sharding_infos (List[EmbeddingShardingInfo]): The sharding infos.
+        Returns:
+            List[List[ShardedEmbeddingTable]]: The sharded embedding tables.
+        """
         world_size = self._world_size
         tables_per_rank: List[List[ShardedEmbeddingTable]] = [
             [] for i in range(world_size)
         ]
@@ -198,7 +210,7 @@ def _shard(
                 ),
             )
 
-            # expectation is planner CW shards across a node, so each CW shard will have local_size num row shards
+            # The planner is expected to place CW shards across a node, so each CW shard will have local_size row shards.
             # pyre-fixme [6]
             for i, rank in enumerate(info.param_sharding.ranks):
                 tables_per_rank[rank].append(
@@ -212,7 +224,6 @@ def _shard(
                         pooling=info.embedding_config.pooling,
                         is_weighted=info.embedding_config.is_weighted,
                         has_feature_processor=info.embedding_config.has_feature_processor,
-                        # sharding by row and col
                         local_rows=shards[i].shard_sizes[0],
                         local_cols=shards[i].shard_sizes[1],
                         compute_kernel=EmbeddingComputeKernel(
@@ -420,7 +431,7 @@ class GridPooledEmbeddingSharding(
     ]
 ):
     """
-    Shards embedding bags table-wise then row-wise.
+    Shards embedding bags into column-wise shards and shards each CW shard table-wise row-wise within a node.
     """
 
     def create_input_dist(
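
Note (illustrative, not part of the patch): after this change the kernel recognizes only two optimizer-state layouts. A row-wise state is 1-D with one entry per embedding row, while a point-wise state has the same shape as the parameter; the dim() == 1 check in the hunk above selects between them. The sketch below is a minimal illustration of that check on plain tensors, assuming only torch; the helper name and example shapes are hypothetical.

import torch


def classify_optimizer_state(param: torch.Tensor, state: torch.Tensor) -> str:
    # Hypothetical helper, not part of torchrec: mirrors the dim()-based
    # dispatch used in the patched kernel on plain tensors.
    if state.dim() == 1 and state.size(0) == param.size(0):
        # One value per embedding row, e.g. a row-wise Adagrad accumulator.
        return "rowwise"
    if state.shape == param.shape:
        # Same shape as the parameter, e.g. exact Adagrad / Adam moments.
        return "pointwise"
    raise ValueError(f"unexpected optimizer state shape {tuple(state.shape)}")


# Example: a 100 x 16 embedding shard.
param = torch.empty(100, 16)
assert classify_optimizer_state(param, torch.empty(100)) == "rowwise"
assert classify_optimizer_state(param, torch.empty(100, 16)) == "pointwise"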