2024-10-16 nightly release (d317c0b)
pytorchbot committed Oct 16, 2024
1 parent 55ea80c commit 5e6ebb9
Showing 2 changed files with 25 additions and 59 deletions.
65 changes: 10 additions & 55 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -211,42 +211,6 @@ class ShardParams:
local_metadata: List[ShardMetadata]
embedding_weights: List[torch.Tensor]

def get_optimizer_single_value_shard_metadata_and_global_metadata(
table_global_metadata: ShardedTensorMetadata,
optimizer_state: torch.Tensor,
) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
table_global_shards_metadata: List[ShardMetadata] = (
table_global_metadata.shards_metadata
)

table_shard_metadata_to_optimizer_shard_metadata = {}
for offset, table_shard_metadata in enumerate(table_global_shards_metadata):
table_shard_metadata_to_optimizer_shard_metadata[
table_shard_metadata
] = ShardMetadata(
shard_sizes=[1], # single value optimizer state
shard_offsets=[offset], # offset increases by 1 for each shard
placement=table_shard_metadata.placement,
)

tensor_properties = TensorProperties(
dtype=optimizer_state.dtype,
layout=optimizer_state.layout,
requires_grad=False,
)
single_value_optimizer_st_metadata = ShardedTensorMetadata(
shards_metadata=list(
table_shard_metadata_to_optimizer_shard_metadata.values()
),
size=torch.Size([len(table_global_shards_metadata)]),
tensor_properties=tensor_properties,
)

return (
table_shard_metadata_to_optimizer_shard_metadata,
single_value_optimizer_st_metadata,
)

def get_optimizer_rowwise_shard_metadata_and_global_metadata(
table_global_metadata: ShardedTensorMetadata,
optimizer_state: torch.Tensor,
@@ -392,10 +356,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
if optimizer_states:
optimizer_state_values = tuple(optimizer_states.values())
for optimizer_state_value in optimizer_state_values:
assert (
table_config.local_rows == optimizer_state_value.size(0)
or optimizer_state_value.nelement() == 1 # single value state
)
assert table_config.local_rows == optimizer_state_value.size(0)
optimizer_states_keys_by_table[table_config.name] = list(
optimizer_states.keys()
)
@@ -474,35 +435,29 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
momentum_local_shards: List[Shard] = []
optimizer_sharded_tensor_metadata: ShardedTensorMetadata

optim_state = shard_params.optimizer_states[0][momentum_idx - 1] # pyre-ignore[16]
if optim_state.nelement() == 1:
# single value state: one value per table
(
table_shard_metadata_to_optimizer_shard_metadata,
optimizer_sharded_tensor_metadata,
) = get_optimizer_single_value_shard_metadata_and_global_metadata(
table_config.global_metadata,
optim_state,
)
elif optim_state.dim() == 1:
# rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1
is_rowwise_optimizer_state: bool = (
# pyre-ignore
shard_params.optimizer_states[0][momentum_idx - 1].dim()
== 1
)

if is_rowwise_optimizer_state:
(
table_shard_metadata_to_optimizer_shard_metadata,
optimizer_sharded_tensor_metadata,
) = get_optimizer_rowwise_shard_metadata_and_global_metadata(
table_config.global_metadata,
optim_state,
shard_params.optimizer_states[0][momentum_idx - 1],
sharding_dim,
is_grid_sharded,
)
else:
# pointwise state: param.shape == state.shape
(
table_shard_metadata_to_optimizer_shard_metadata,
optimizer_sharded_tensor_metadata,
) = get_optimizer_pointwise_shard_metadata_and_global_metadata(
table_config.global_metadata,
optim_state,
shard_params.optimizer_states[0][momentum_idx - 1],
)

for optimizer_state, table_shard_local_metadata in zip(
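For context on the change above: the removed branch in get_sharded_optim_state distinguished a third optimizer-state shape ("single value", one scalar per table) in addition to rowwise and pointwise states. Below is a minimal, self-contained sketch of that shape convention; classify_optimizer_state is a hypothetical helper for illustration only, not torchrec code.

import torch


def classify_optimizer_state(param: torch.Tensor, state: torch.Tensor) -> str:
    """Classify an optimizer state tensor by its shape relative to the parameter."""
    if state.nelement() == 1:
        # single value state: one scalar per table (handling removed by this commit)
        return "single_value"
    if state.dim() == 1 and state.size(0) == param.size(0):
        # rowwise state: param.shape[0] == state.shape[0]
        return "rowwise"
    if state.shape == param.shape:
        # pointwise state: param.shape == state.shape
        return "pointwise"
    raise ValueError(f"unexpected optimizer state shape {tuple(state.shape)}")


# Example with a 100-row, 64-dim embedding shard:
param = torch.zeros(100, 64)
print(classify_optimizer_state(param, torch.zeros(100)))      # rowwise
print(classify_optimizer_state(param, torch.zeros(100, 64)))  # pointwise
print(classify_optimizer_state(param, torch.zeros(1)))        # single_value

After this commit only the rowwise and pointwise cases are handled, which is why the assert in the first hunk again requires table_config.local_rows == optimizer_state_value.size(0).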
19 changes: 15 additions & 4 deletions torchrec/distributed/sharding/grid_sharding.py
@@ -115,7 +115,8 @@ def __init__(

def _init_combined_embeddings(self) -> None:
"""
similar to CW sharding, but this time each CW shard is on a node and not rank
Initializes combined embeddings, similar to the CW sharding implementation,
but in this case the CW shard is treated on a per node basis and not per rank.
"""
embedding_names = []
for grouped_embedding_configs in self._grouped_embedding_configs_per_node:
@@ -179,6 +180,17 @@ def _shard(
self,
sharding_infos: List[EmbeddingShardingInfo],
) -> List[List[ShardedEmbeddingTable]]:
"""
Shards the embedding tables.
This method takes the sharding infos and returns a list of lists of
sharded embedding tables, where each inner list represents the tables
for a specific rank.
Args:
sharding_infos (List[EmbeddingShardingInfo]): The sharding infos.
Returns:
List[List[ShardedEmbeddingTable]]: The sharded embedding tables.
"""
world_size = self._world_size
tables_per_rank: List[List[ShardedEmbeddingTable]] = [
[] for i in range(world_size)
@@ -198,7 +210,7 @@
),
)

# expectation is planner CW shards across a node, so each CW shard will have local_size num row shards
# Expectation is planner CW shards across a node, so each CW shard will have local_size number of row shards
# pyre-fixme [6]
for i, rank in enumerate(info.param_sharding.ranks):
tables_per_rank[rank].append(
Expand All @@ -212,7 +224,6 @@ def _shard(
pooling=info.embedding_config.pooling,
is_weighted=info.embedding_config.is_weighted,
has_feature_processor=info.embedding_config.has_feature_processor,
# sharding by row and col
local_rows=shards[i].shard_sizes[0],
local_cols=shards[i].shard_sizes[1],
compute_kernel=EmbeddingComputeKernel(
@@ -420,7 +431,7 @@ class GridPooledEmbeddingSharding(
]
):
"""
Shards embedding bags table-wise then row-wise.
Shards embedding bags into column wise shards and shards each CW shard table wise row wise within a node
"""

def create_input_dist(
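To make the grid layout described in the docstrings above concrete, here is a hypothetical, self-contained sketch of the resulting shard geometry: the table is first split column-wise with one CW shard per node, and each CW shard is then split row-wise across the local_size ranks of that node. The function name and the even-split assumption are illustrative only and do not reflect the planner's actual placement logic.

from typing import List, Tuple


def grid_shard_geometry(
    num_rows: int,
    emb_dim: int,
    world_size: int,
    local_size: int,
) -> List[Tuple[int, int, int, int]]:
    """Per-rank (row_offset, col_offset, rows, cols), assuming even splits."""
    num_nodes = world_size // local_size
    cols_per_node = emb_dim // num_nodes    # column-wise split: one CW shard per node
    rows_per_rank = num_rows // local_size  # row-wise split within each node
    shards = []
    for rank in range(world_size):
        node = rank // local_size
        local_rank = rank % local_size
        shards.append(
            (local_rank * rows_per_rank, node * cols_per_node, rows_per_rank, cols_per_node)
        )
    return shards


# Example: 2 nodes x 4 ranks each; every rank holds a 250 x 64 block of a 1000 x 128 table.
for rank, shard in enumerate(grid_shard_geometry(1000, 128, world_size=8, local_size=4)):
    print(rank, shard)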
