2024-10-18 nightly release (d1a2990)
pytorchbot committed Oct 18, 2024
1 parent 555d3db commit 4baf8ee
Showing 3 changed files with 72 additions and 12 deletions.
73 changes: 61 additions & 12 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -211,6 +211,42 @@ class ShardParams:
local_metadata: List[ShardMetadata]
embedding_weights: List[torch.Tensor]

def get_optimizer_single_value_shard_metadata_and_global_metadata(
table_global_metadata: ShardedTensorMetadata,
optimizer_state: torch.Tensor,
) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
table_global_shards_metadata: List[ShardMetadata] = (
table_global_metadata.shards_metadata
)

table_shard_metadata_to_optimizer_shard_metadata = {}
for offset, table_shard_metadata in enumerate(table_global_shards_metadata):
table_shard_metadata_to_optimizer_shard_metadata[
table_shard_metadata
] = ShardMetadata(
shard_sizes=[1], # single value optimizer state
shard_offsets=[offset], # offset increases by 1 for each shard
placement=table_shard_metadata.placement,
)

tensor_properties = TensorProperties(
dtype=optimizer_state.dtype,
layout=optimizer_state.layout,
requires_grad=False,
)
single_value_optimizer_st_metadata = ShardedTensorMetadata(
shards_metadata=list(
table_shard_metadata_to_optimizer_shard_metadata.values()
),
size=torch.Size([len(table_global_shards_metadata)]),
tensor_properties=tensor_properties,
)

return (
table_shard_metadata_to_optimizer_shard_metadata,
single_value_optimizer_st_metadata,
)
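
For orientation, here is a minimal sketch (plain Python with hypothetical placement strings, not the library API) of the layout the new helper describes: a single-value optimizer state becomes a 1-D sharded tensor with one element per table shard, where shard i sits at offset i, has size 1, and reuses the corresponding table shard's placement.

# Sketch only: mirrors the sizes/offsets computed above for a table with 4 shards.
table_shard_placements = ["rank:0/cuda:0", "rank:1/cuda:1", "rank:2/cuda:2", "rank:3/cuda:3"]

single_value_state_shards = [
    {"shard_sizes": [1], "shard_offsets": [i], "placement": placement}
    for i, placement in enumerate(table_shard_placements)
]
global_size = [len(table_shard_placements)]  # the real metadata uses torch.Size([4])
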

def get_optimizer_rowwise_shard_metadata_and_global_metadata(
table_global_metadata: ShardedTensorMetadata,
optimizer_state: torch.Tensor,
@@ -356,7 +392,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
if optimizer_states:
optimizer_state_values = tuple(optimizer_states.values())
for optimizer_state_value in optimizer_state_values:
- assert table_config.local_rows == optimizer_state_value.size(0)
+ assert (
+ table_config.local_rows == optimizer_state_value.size(0)
+ or optimizer_state_value.nelement() == 1 # single value state
+ )
optimizer_states_keys_by_table[table_config.name] = list(
optimizer_states.keys()
)
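
As a quick illustration of the relaxed check (standalone example, assuming local_rows = 100; not part of the diff), all three state shapes now pass:

import torch

local_rows = 100
rowwise_state = torch.zeros(local_rows)        # one value per local row
pointwise_state = torch.zeros(local_rows, 8)   # one value per parameter element
single_value_state = torch.zeros(1)            # one value per table (newly accepted)

for state in (rowwise_state, pointwise_state, single_value_state):
    assert local_rows == state.size(0) or state.nelement() == 1
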
@@ -430,34 +469,44 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
opt_state is not None for opt_state in shard_params.optimizer_states
):
# pyre-ignore
- def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
+ def get_sharded_optim_state(
+ momentum_idx: int, state_key: str
+ ) -> ShardedTensor:
assert momentum_idx > 0
momentum_local_shards: List[Shard] = []
optimizer_sharded_tensor_metadata: ShardedTensorMetadata

- is_rowwise_optimizer_state: bool = (
- # pyre-ignore
- shard_params.optimizer_states[0][momentum_idx - 1].dim()
- == 1
- )
-
- if is_rowwise_optimizer_state:
+ optim_state = shard_params.optimizer_states[0][momentum_idx - 1] # pyre-ignore[16]
+ if (
+ optim_state.nelement() == 1 and state_key != "momentum1"
+ ): # special handling for backward compatibility, momentum1 is rowwise state for rowwise_adagrad
+ # single value state: one value per table
+ (
+ table_shard_metadata_to_optimizer_shard_metadata,
+ optimizer_sharded_tensor_metadata,
+ ) = get_optimizer_single_value_shard_metadata_and_global_metadata(
+ table_config.global_metadata,
+ optim_state,
+ )
+ elif optim_state.dim() == 1:
+ # rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1
(
table_shard_metadata_to_optimizer_shard_metadata,
optimizer_sharded_tensor_metadata,
) = get_optimizer_rowwise_shard_metadata_and_global_metadata(
table_config.global_metadata,
- shard_params.optimizer_states[0][momentum_idx - 1],
+ optim_state,
sharding_dim,
is_grid_sharded,
)
else:
+ # pointwise state: param.shape == state.shape
(
table_shard_metadata_to_optimizer_shard_metadata,
optimizer_sharded_tensor_metadata,
) = get_optimizer_pointwise_shard_metadata_and_global_metadata(
table_config.global_metadata,
- shard_params.optimizer_states[0][momentum_idx - 1],
+ optim_state,
)

for optimizer_state, table_shard_local_metadata in zip(
@@ -499,7 +548,7 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
cur_state_key = optimizer_state_keys[cur_state_idx]

state[weight][f"{table_config.name}.{cur_state_key}"] = (
- get_sharded_optim_state(cur_state_idx + 1)
+ get_sharded_optim_state(cur_state_idx + 1, cur_state_key)
)

super().__init__(params, state, [param_group])
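
In summary, the rewritten get_sharded_optim_state now dispatches on the state's shape and key. A standalone sketch of the decision order (hypothetical helper name; state keys are illustrative only):

import torch

def classify_optimizer_state(state: torch.Tensor, state_key: str) -> str:
    # momentum1 stays on the rowwise path for backward compatibility:
    # per the diff comment, it is the rowwise state of rowwise_adagrad.
    if state.nelement() == 1 and state_key != "momentum1":
        return "single_value"  # one value per table
    if state.dim() == 1:
        return "rowwise"       # one value per embedding row
    return "pointwise"         # one value per parameter element

assert classify_optimizer_state(torch.zeros(1), "some_scalar_state") == "single_value"
assert classify_optimizer_state(torch.zeros(1024), "momentum1") == "rowwise"
assert classify_optimizer_state(torch.zeros(1024, 64), "momentum2") == "pointwise"
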
5 changes: 5 additions & 0 deletions torchrec/distributed/mc_modules.py
@@ -646,6 +646,8 @@ def compute(
table: JaggedTensor(
values=kjt.values(),
lengths=kjt.lengths(),
# TODO: improve this temp solution by passing real weights
weights=torch.tensor(kjt.length_per_key()),
)
}
mcm = self._managed_collision_modules[table]
@@ -660,6 +662,8 @@ def compute(
table: JaggedTensor(
values=features.values(),
lengths=features.lengths(),
# TODO: improve this temp solution by passing real weights
weights=torch.tensor(kjt.length_per_key()),
)
}
mcm = self._managed_collision_modules[table]
@@ -673,6 +677,7 @@ def compute(
keys=fns,
values=values,
lengths=features.lengths(),
# original weights instead of features splits
weights=features.weights_or_none(),
)
)
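
For reference, a small eager-mode example of the value the TODO comments above temporarily pass through the JaggedTensor weights field (assumes torchrec is installed; feature names and values are made up): KeyedJaggedTensor.length_per_key() is the total number of values per key, converted here to a tensor.

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([10, 11, 12, 13, 14]),
    lengths=torch.tensor([2, 1, 1, 1]),  # f1 -> [2, 1], f2 -> [1, 1]
)
print(kjt.length_per_key())  # [3, 2]

jt = JaggedTensor(
    values=kjt.values(),
    lengths=kjt.lengths(),
    weights=torch.tensor(kjt.length_per_key()),  # temporary stand-in for real weights
)
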
6 changes: 6 additions & 0 deletions torchrec/modules/mc_modules.py
@@ -76,6 +76,11 @@ def _mcc_lazy_init(
return (features, created_feature_order, features_order)


@torch.fx.wrap
def _get_length_per_key(kjt: KeyedJaggedTensor) -> torch.Tensor:
return torch.tensor(kjt.length_per_key())
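
The torch.fx.wrap decorator keeps this helper as a single opaque call during FX symbolic tracing, so the Python-list handling inside length_per_key() is not traced through. A self-contained sketch of that behavior (hypothetical _list_to_tensor and M, unrelated to torchrec):

import torch
import torch.fx

@torch.fx.wrap
def _list_to_tensor(xs):
    # Python-level work that symbolic tracing should not step into
    return torch.tensor([int(x) for x in xs])

class M(torch.nn.Module):
    def forward(self, x):
        return x + _list_to_tensor([1, 2, 3])

gm = torch.fx.symbolic_trace(M())
print(gm.graph)  # _list_to_tensor appears as a single call_function node
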


@torch.no_grad()
def dynamic_threshold_filter(
id_counts: torch.Tensor,
@@ -368,6 +373,7 @@ def forward(
table: JaggedTensor(
values=kjt.values(),
lengths=kjt.lengths(),
weights=_get_length_per_key(kjt),
)
}
mc_input = mc_module(mc_input)
