Clean up duplicate code in EmbeddingCollection and EmbeddingBagCollection (pytorch#1666)

Summary:

`EmbeddingCollection` and `EmbeddingBagCollection` contain duplicate code for the methods `_pre_load_state_dict_hook()` and `_initialize_torch_state()`. Refactor this out into an abstract class `ShardedTensorEmbeddingModule` from which both EC and EBC inherit.

Differential Revision: D53198210
sarckk authored and facebook-github-bot committed Jan 30, 2024
1 parent 0e153c9 commit 1375abd
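
The refactor described in the summary follows a standard mixin pattern: the duplicated state-dict plumbing moves into one shared base class, and each sharded collection supplies only what differs between the two. The sketch below is a simplified, hypothetical illustration of that pattern, not the actual torchrec implementation; the class names `SharedShardedStateMixin`, `MyShardedEmbeddingCollection`, and `MyShardedEmbeddingBagCollection` and the stubbed hook body are assumptions, while `module_weight_key`, `pre_load_state_dict_hook`, and `post_state_dict_hook` mirror names visible in the diff below.

```python
# Minimal sketch of the deduplication pattern (illustration only, assumed names).
import abc
from typing import Any, Dict

import torch.nn as nn


class SharedShardedStateMixin(abc.ABC):
    """Stand-in for the shared base class holding the common state_dict hooks."""

    @property
    @abc.abstractmethod
    def module_weight_key(self) -> str:
        """Prefix under which the collection registers its tables."""

    @staticmethod
    def pre_load_state_dict_hook(
        module: "SharedShardedStateMixin",
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> None:
        # The logic formerly duplicated in EC and EBC (rewriting ShardedTensor
        # entries into plain local tensors) would live here, keyed off
        # f"{prefix}{module.module_weight_key}.{table_name}.weight".
        ...


class MyShardedEmbeddingCollection(SharedShardedStateMixin, nn.Module):
    @property
    def module_weight_key(self) -> str:
        return "embeddings"


class MyShardedEmbeddingBagCollection(SharedShardedStateMixin, nn.Module):
    @property
    def module_weight_key(self) -> str:
        return "embedding_bags"


if __name__ == "__main__":
    print(MyShardedEmbeddingCollection().module_weight_key)      # embeddings
    print(MyShardedEmbeddingBagCollection().module_weight_key)   # embedding_bags
```

In the diff below, `ShardedEmbeddingCollection._initialize_torch_state()` now builds its module dict via `init_embedding_modules()` and registers hooks provided by the shared base, keeping only the `module_weight_key` property (`"embeddings"`) specific to the collection.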
Showing 4 changed files with 484 additions and 299 deletions.
176 changes: 27 additions & 149 deletions torchrec/distributed/embedding.py
@@ -9,7 +9,7 @@
import copy
import logging
import warnings
from collections import defaultdict, deque, OrderedDict
from collections import defaultdict, deque
from dataclasses import dataclass, field
from itertools import accumulate
from typing import Any, cast, Dict, List, MutableMapping, Optional, Type, Union
@@ -28,6 +28,7 @@
EmbeddingComputeKernel,
KJTList,
ShardedEmbeddingModule,
ShardedEmbeddingModuleState,
ShardingType,
)
from torchrec.distributed.sharding.cw_sequence_sharding import (
@@ -71,7 +72,7 @@
EmbeddingCollectionInterface,
)
from torchrec.modules.utils import construct_jagged_tensors
from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
from torchrec.optim.fused import FusedOptimizerModule
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

@@ -297,6 +298,7 @@ class ShardedEmbeddingCollection(
Dict[str, JaggedTensor],
EmbeddingCollectionContext,
],
ShardedEmbeddingModuleState,
# TODO remove after compute_kernel X sharding decoupling
FusedOptimizerModule,
):
@@ -421,153 +423,9 @@ def __init__(
if module.device != torch.device("meta"):
self.load_state_dict(module.state_dict())

@staticmethod
def _pre_load_state_dict_hook(
self: "ShardedEmbeddingCollection",
state_dict: Dict[str, Any],
prefix: str,
*args: Any,
) -> None:
"""
Modify the destination state_dict for model parallel
to transform from ShardedTensors into tensors
"""
for (
table_name,
model_shards,
) in self._model_parallel_name_to_local_shards.items():
key = f"{prefix}embeddings.{table_name}.weight"

# If state_dict[key] is already a ShardedTensor, use its local shards
if isinstance(state_dict[key], ShardedTensor):
local_shards = state_dict[key].local_shards()
# If no local shards, create an empty tensor
if len(local_shards) == 0:
state_dict[key] = torch.empty(0)
else:
dim = state_dict[key].metadata().shards_metadata[0].shard_sizes[1]
# CW multiple shards are merged
if len(local_shards) > 1:
state_dict[key] = torch.cat(
[s.tensor.view(-1) for s in local_shards], dim=0
).view(-1, dim)
else:
state_dict[key] = local_shards[0].tensor.view(-1, dim)
else:
local_shards = []
for shard in model_shards:
# Extract shard size and offsets for splicing
shard_sizes = shard.metadata.shard_sizes
shard_offsets = shard.metadata.shard_offsets

# Prepare tensor by splicing and placing on appropriate device
spliced_tensor = state_dict[key][
shard_offsets[0] : shard_offsets[0] + shard_sizes[0],
shard_offsets[1] : shard_offsets[1] + shard_sizes[1],
].to(shard.tensor.get_device())

# Append spliced tensor into local shards
local_shards.append(spliced_tensor)

state_dict[key] = (
torch.empty(0)
if not local_shards
else torch.cat(local_shards, dim=0)
)

def _initialize_torch_state(self) -> None: # noqa
"""
This provides consistency between this class and the EmbeddingCollection's
nn.Module API calls (state_dict, named_modules, etc)
"""

self.embeddings: nn.ModuleDict = nn.ModuleDict()
for table_name in self._table_names:
self.embeddings[table_name] = nn.Module()
self._model_parallel_name_to_local_shards = OrderedDict()
self._model_parallel_name_to_sharded_tensor = OrderedDict()
model_parallel_name_to_compute_kernel: Dict[str, str] = {}
for (
table_name,
parameter_sharding,
) in self.module_sharding_plan.items():
if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
continue
self._model_parallel_name_to_local_shards[table_name] = []
model_parallel_name_to_compute_kernel[
table_name
] = parameter_sharding.compute_kernel

self._name_to_table_size = {}
for table in self._embedding_configs:
self._name_to_table_size[table.name] = (
table.num_embeddings,
table.embedding_dim,
)

for sharding_type, lookup in zip(
self._sharding_type_to_sharding.keys(), self._lookups
):
if sharding_type == ShardingType.DATA_PARALLEL.value:
# unwrap DDP
lookup = lookup.module
else:
# save local_shards for transforming MP params to shardedTensor
for key, v in lookup.state_dict().items():
table_name = key[: -len(".weight")]
self._model_parallel_name_to_local_shards[table_name].extend(
v.local_shards()
)
for (
table_name,
tbe_slice,
) in lookup.named_parameters_by_table():
self.embeddings[table_name].register_parameter("weight", tbe_slice)
for (
table_name,
local_shards,
) in self._model_parallel_name_to_local_shards.items():
# for shards that don't exist on this rank, register with empty tensor
if not hasattr(self.embeddings[table_name], "weight"):
self.embeddings[table_name].register_parameter(
"weight", nn.Parameter(torch.empty(0))
)
if (
model_parallel_name_to_compute_kernel[table_name]
!= EmbeddingComputeKernel.DENSE.value
):
self.embeddings[table_name].weight._in_backward_optimizers = [
EmptyFusedOptimizer()
]
# created ShardedTensors once in init, use in post_state_dict_hook
self._model_parallel_name_to_sharded_tensor[
table_name
] = ShardedTensor._init_from_local_shards(
local_shards,
self._name_to_table_size[table_name],
process_group=self._env.process_group,
)

def post_state_dict_hook(
module: ShardedEmbeddingCollection,
destination: Dict[str, torch.Tensor],
prefix: str,
_local_metadata: Dict[str, Any],
) -> None:
# Adjust dense MP
for (
table_name,
sharded_t,
) in module._model_parallel_name_to_sharded_tensor.items():
destination_key = f"{prefix}embeddings.{table_name}.weight"
destination[destination_key] = sharded_t

self._register_state_dict_hook(post_state_dict_hook)
self._register_load_state_dict_pre_hook(
self._pre_load_state_dict_hook, with_module=True
)

self.reset_parameters()
@property
def module_weight_key(self) -> str:
return "embeddings"

def reset_parameters(self) -> None:
if self._device and self._device.type == "meta":
@@ -579,6 +437,26 @@ def reset_parameters(self) -> None:
# pyre-ignore
table_config.init_fn(param)

def _initialize_torch_state(self) -> None: # noqa
"""
Provides consistency between this class and the EmbeddingCollection's
nn.Module API calls (state_dict, named_modules, etc)
"""

# Set module dict in child class so it gets registered in self._modules
self.embeddings = self.init_embedding_modules(
self.module_sharding_plan,
self._embedding_configs,
self._sharding_type_to_sharding.keys(),
self._lookups,
self._env.process_group,
)
self._register_state_dict_hook(self.post_state_dict_hook)
self._register_load_state_dict_pre_hook(
self.pre_load_state_dict_hook, with_module=True
)
self.reset_parameters()

def _generate_permute_indices_per_feature(
self,
embedding_configs: List[EmbeddingConfig],
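
For context on the state-dict transformation being consolidated (the `_pre_load_state_dict_hook` removed above), here is a small, self-contained sketch of its splicing branch: given a full dense weight and per-shard offsets/sizes, each rank keeps only its own slices of the table. The `ShardMeta` container, table name, and shapes are assumptions for illustration, and the device placement step is omitted; this is not torchrec code and not part of the commit.

```python
# Illustration only (not part of this commit): splice a full dense weight
# into the local shard(s) a rank expects before load_state_dict().
from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class ShardMeta:
    shard_offsets: List[int]  # (row_offset, col_offset) of one local shard
    shard_sizes: List[int]    # (rows, cols) of that shard


def splice_local_shards(
    state_dict: Dict[str, torch.Tensor],
    key: str,
    shards: List[ShardMeta],
) -> None:
    """Replace a full dense weight with the concatenation of this rank's shards."""
    local = []
    for shard in shards:
        rows, cols = shard.shard_sizes
        row_off, col_off = shard.shard_offsets
        local.append(
            state_dict[key][row_off : row_off + rows, col_off : col_off + cols]
        )
    # Ranks that own no shard of this table fall back to an empty tensor.
    state_dict[key] = torch.cat(local, dim=0) if local else torch.empty(0)


if __name__ == "__main__":
    sd = {"embeddings.table_0.weight": torch.arange(32.0).view(8, 4)}
    # Pretend this rank owns rows 0..3 (row-wise sharding across 2 ranks).
    splice_local_shards(sd, "embeddings.table_0.weight", [ShardMeta([0, 0], [4, 4])])
    print(sd["embeddings.table_0.weight"].shape)  # torch.Size([4, 4])
```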