From d9bcb2f57fe3cff1eb4f63ba6f7066f5497cfeac Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Mon, 29 Jan 2024 16:27:37 -0800 Subject: [PATCH] Clean up duplicate code in EmbeddingCollection and EmbeddingBagCollection (#1666) Summary: `EmbeddingCollection` and `EmbeddingBagCollection` contain duplicate code for methods `_pre_load_state_dict_hook()` and `_initialize_torch_state()`. Refactor this out to an abstract class `ShardedTensorEmbeddingModule` which both EC and EBC inherit Differential Revision: D53198210 --- torchrec/distributed/embedding.py | 176 ++----------- torchrec/distributed/embedding_state.py | 191 ++++++++++++++ torchrec/distributed/embedding_types.py | 1 + torchrec/distributed/embeddingbag.py | 175 ++----------- .../distributed/tests/test_module_state.py | 245 ++++++++++++++++++ 5 files changed, 490 insertions(+), 298 deletions(-) create mode 100644 torchrec/distributed/embedding_state.py create mode 100644 torchrec/distributed/tests/test_module_state.py diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 322e5e96c..33ec4b519 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -9,7 +9,7 @@ import copy import logging import warnings -from collections import defaultdict, deque, OrderedDict +from collections import defaultdict, deque from dataclasses import dataclass, field from itertools import accumulate from typing import Any, cast, Dict, List, MutableMapping, Optional, Type, Union @@ -23,6 +23,7 @@ EmbeddingShardingInfo, KJTListSplitsAwaitable, ) +from torchrec.distributed.embedding_state import ShardedEmbeddingModuleState from torchrec.distributed.embedding_types import ( BaseEmbeddingSharder, EmbeddingComputeKernel, @@ -71,7 +72,7 @@ EmbeddingCollectionInterface, ) from torchrec.modules.utils import construct_jagged_tensors -from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule +from torchrec.optim.fused import FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor @@ -297,6 +298,7 @@ class ShardedEmbeddingCollection( Dict[str, JaggedTensor], EmbeddingCollectionContext, ], + ShardedEmbeddingModuleState, # TODO remove after compute_kernel X sharding decoupling FusedOptimizerModule, ): @@ -421,153 +423,9 @@ def __init__( if module.device != torch.device("meta"): self.load_state_dict(module.state_dict()) - @staticmethod - def _pre_load_state_dict_hook( - self: "ShardedEmbeddingCollection", - state_dict: Dict[str, Any], - prefix: str, - *args: Any, - ) -> None: - """ - Modify the destination state_dict for model parallel - to transform from ShardedTensors into tensors - """ - for ( - table_name, - model_shards, - ) in self._model_parallel_name_to_local_shards.items(): - key = f"{prefix}embeddings.{table_name}.weight" - - # If state_dict[key] is already a ShardedTensor, use its local shards - if isinstance(state_dict[key], ShardedTensor): - local_shards = state_dict[key].local_shards() - # If no local shards, create an empty tensor - if len(local_shards) == 0: - state_dict[key] = torch.empty(0) - else: - dim = state_dict[key].metadata().shards_metadata[0].shard_sizes[1] - # CW multiple shards are merged - if len(local_shards) > 1: - state_dict[key] = torch.cat( - [s.tensor.view(-1) for s in local_shards], dim=0 - ).view(-1, dim) - else: - state_dict[key] = local_shards[0].tensor.view(-1, dim) - else: - local_shards = [] - for shard in model_shards: - # Extract shard size 
and offsets for splicing - shard_sizes = shard.metadata.shard_sizes - shard_offsets = shard.metadata.shard_offsets - - # Prepare tensor by splicing and placing on appropriate device - spliced_tensor = state_dict[key][ - shard_offsets[0] : shard_offsets[0] + shard_sizes[0], - shard_offsets[1] : shard_offsets[1] + shard_sizes[1], - ].to(shard.tensor.get_device()) - - # Append spliced tensor into local shards - local_shards.append(spliced_tensor) - - state_dict[key] = ( - torch.empty(0) - if not local_shards - else torch.cat(local_shards, dim=0) - ) - - def _initialize_torch_state(self) -> None: # noqa - """ - This provides consistency between this class and the EmbeddingCollection's - nn.Module API calls (state_dict, named_modules, etc) - """ - - self.embeddings: nn.ModuleDict = nn.ModuleDict() - for table_name in self._table_names: - self.embeddings[table_name] = nn.Module() - self._model_parallel_name_to_local_shards = OrderedDict() - self._model_parallel_name_to_sharded_tensor = OrderedDict() - model_parallel_name_to_compute_kernel: Dict[str, str] = {} - for ( - table_name, - parameter_sharding, - ) in self.module_sharding_plan.items(): - if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: - continue - self._model_parallel_name_to_local_shards[table_name] = [] - model_parallel_name_to_compute_kernel[ - table_name - ] = parameter_sharding.compute_kernel - - self._name_to_table_size = {} - for table in self._embedding_configs: - self._name_to_table_size[table.name] = ( - table.num_embeddings, - table.embedding_dim, - ) - - for sharding_type, lookup in zip( - self._sharding_type_to_sharding.keys(), self._lookups - ): - if sharding_type == ShardingType.DATA_PARALLEL.value: - # unwrap DDP - lookup = lookup.module - else: - # save local_shards for transforming MP params to shardedTensor - for key, v in lookup.state_dict().items(): - table_name = key[: -len(".weight")] - self._model_parallel_name_to_local_shards[table_name].extend( - v.local_shards() - ) - for ( - table_name, - tbe_slice, - ) in lookup.named_parameters_by_table(): - self.embeddings[table_name].register_parameter("weight", tbe_slice) - for ( - table_name, - local_shards, - ) in self._model_parallel_name_to_local_shards.items(): - # for shards that don't exist on this rank, register with empty tensor - if not hasattr(self.embeddings[table_name], "weight"): - self.embeddings[table_name].register_parameter( - "weight", nn.Parameter(torch.empty(0)) - ) - if ( - model_parallel_name_to_compute_kernel[table_name] - != EmbeddingComputeKernel.DENSE.value - ): - self.embeddings[table_name].weight._in_backward_optimizers = [ - EmptyFusedOptimizer() - ] - # created ShardedTensors once in init, use in post_state_dict_hook - self._model_parallel_name_to_sharded_tensor[ - table_name - ] = ShardedTensor._init_from_local_shards( - local_shards, - self._name_to_table_size[table_name], - process_group=self._env.process_group, - ) - - def post_state_dict_hook( - module: ShardedEmbeddingCollection, - destination: Dict[str, torch.Tensor], - prefix: str, - _local_metadata: Dict[str, Any], - ) -> None: - # Adjust dense MP - for ( - table_name, - sharded_t, - ) in module._model_parallel_name_to_sharded_tensor.items(): - destination_key = f"{prefix}embeddings.{table_name}.weight" - destination[destination_key] = sharded_t - - self._register_state_dict_hook(post_state_dict_hook) - self._register_load_state_dict_pre_hook( - self._pre_load_state_dict_hook, with_module=True - ) - - self.reset_parameters() + @property + def 
module_weight_key(self) -> str: + return "embeddings" def reset_parameters(self) -> None: if self._device and self._device.type == "meta": @@ -579,6 +437,26 @@ def reset_parameters(self) -> None: # pyre-ignore table_config.init_fn(param) + def _initialize_torch_state(self) -> None: # noqa + """ + Provides consistency between this class and the EmbeddingCollection's + nn.Module API calls (state_dict, named_modules, etc) + """ + + # Set module dict in child class so it gets registered in self._modules + self.embeddings = self.init_embedding_modules( + self.module_sharding_plan, + self._embedding_configs, + self._sharding_type_to_sharding.keys(), + self._lookups, + self._env.process_group, + ) + self._register_state_dict_hook(self.post_state_dict_hook) + self._register_load_state_dict_pre_hook( + self.pre_load_state_dict_hook, with_module=True + ) + self.reset_parameters() + def _generate_permute_indices_per_feature( self, embedding_configs: List[EmbeddingConfig], diff --git a/torchrec/distributed/embedding_state.py b/torchrec/distributed/embedding_state.py new file mode 100644 index 000000000..aa1e070a7 --- /dev/null +++ b/torchrec/distributed/embedding_state.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +from collections import OrderedDict +from typing import Any, Dict, Iterable, List, Optional, Sequence + +import torch + +import torch.distributed as dist +from torch import nn +from torch.distributed._shard.sharded_tensor import Shard +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.types import ( + EmbeddingModuleShardingPlan, + ShardedTensor, + ShardingType, +) +from torchrec.modules.embedding_configs import BaseEmbeddingConfig +from torchrec.optim.fused import EmptyFusedOptimizer + + +class ShardedEmbeddingModuleState(abc.ABC): + _model_parallel_name_to_sharded_tensor: "OrderedDict[str, ShardedTensor]" + _model_parallel_name_to_local_shards: "OrderedDict[str, List[Shard]]" + + @abc.abstractmethod + def __init__(self) -> None: + super().__init__() + self._model_parallel_name_to_sharded_tensor = OrderedDict() + self._model_parallel_name_to_local_shards = OrderedDict() + + @abc.abstractproperty + def module_weight_key(self) -> str: + ... 
+ + def init_embedding_modules( + self, + module_sharding_plan: EmbeddingModuleShardingPlan, + embedding_configs: Sequence[BaseEmbeddingConfig], + sharding_types: Iterable[str], + lookups: List[nn.Module], + pg: Optional[dist.ProcessGroup], + ) -> nn.ModuleDict: + model_parallel_name_to_compute_kernel: Dict[str, str] = {} + for ( + table_name, + parameter_sharding, + ) in module_sharding_plan.items(): + if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: + continue + self._model_parallel_name_to_local_shards[table_name] = [] + model_parallel_name_to_compute_kernel[ + table_name + ] = parameter_sharding.compute_kernel + + name_to_table_size = {} + embeddings = nn.ModuleDict() + + for table in embedding_configs: + embeddings[table.name] = nn.Module() + name_to_table_size[table.name] = ( + table.num_embeddings, + table.embedding_dim, + ) + + for sharding_type, lookup in zip(sharding_types, lookups): + if sharding_type == ShardingType.DATA_PARALLEL.value: + # unwrap DDP + lookup = lookup.module + else: + # save local_shards for transforming MP params to shardedTensor + for key, v in lookup.state_dict().items(): + table_name = key[: -len(".weight")] + self._model_parallel_name_to_local_shards[table_name].extend( + v.local_shards() + ) + for ( + table_name, + tbe_slice, + ) in lookup.named_parameters_by_table(): + embeddings[table_name].register_parameter("weight", tbe_slice) + + for ( + table_name, + local_shards, + ) in self._model_parallel_name_to_local_shards.items(): + # for shards that don't exist on this rank, register with empty tensor + if not hasattr(embeddings[table_name], "weight"): + embeddings[table_name].register_parameter( + "weight", nn.Parameter(torch.empty(0)) + ) + if ( + model_parallel_name_to_compute_kernel[table_name] + != EmbeddingComputeKernel.DENSE.value + ): + embeddings[table_name].weight._in_backward_optimizers = [ + EmptyFusedOptimizer() + ] + # created ShardedTensors once in init, use in post_state_dict_hook + self._model_parallel_name_to_sharded_tensor[ + table_name + ] = ShardedTensor._init_from_local_shards( + local_shards, + name_to_table_size[table_name], + process_group=pg, + ) + + return embeddings + + def construct_state_dict_key( + self: "ShardedEmbeddingModuleState", + prefix: str, + table_name: str, + ) -> str: + return f"{prefix}{self.module_weight_key}.{table_name}.weight" + + @staticmethod + def post_state_dict_hook( + self: "ShardedEmbeddingModuleState", + destination: Dict[str, torch.Tensor], + prefix: str, + _local_metadata: Dict[str, Any], + ) -> None: + # Adjust dense MP + for ( + table_name, + sharded_t, + ) in self._model_parallel_name_to_sharded_tensor.items(): + destination_key = self.construct_state_dict_key(prefix, table_name) + destination[destination_key] = sharded_t + + @staticmethod + def pre_load_state_dict_hook( + self: "ShardedEmbeddingModuleState", + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> None: + """ + Modify the destination state_dict for model parallel + to transform from ShardedTensors into tensors + """ + for ( + table_name, + model_shards, + ) in self._model_parallel_name_to_local_shards.items(): + key = self.construct_state_dict_key(prefix, table_name) + # If state_dict[key] is already a ShardedTensor, use its local shards + if isinstance(state_dict[key], ShardedTensor): + local_shards = state_dict[key].local_shards() + if len(local_shards) == 0: + state_dict[key] = torch.empty(0) + else: + dim = state_dict[key].metadata().shards_metadata[0].shard_sizes[1] + # CW multiple shards 
are merged + if len(local_shards) > 1: + state_dict[key] = torch.cat( + [s.tensor.view(-1) for s in local_shards], dim=0 + ).view(-1, dim) + else: + state_dict[key] = local_shards[0].tensor.view(-1, dim) + elif isinstance(state_dict[key], torch.Tensor): + local_shards = [] + for shard in model_shards: + # Extract shard size and offsets for splicing + shard_sizes = shard.metadata.shard_sizes + shard_offsets = shard.metadata.shard_offsets + + # Prepare tensor by splicing and placing on appropriate device + spliced_tensor = state_dict[key][ + shard_offsets[0] : shard_offsets[0] + shard_sizes[0], + shard_offsets[1] : shard_offsets[1] + shard_sizes[1], + ] + + # Append spliced tensor into local shards + local_shards.append(spliced_tensor) + + state_dict[key] = ( + torch.empty(0) + if not local_shards + else torch.cat(local_shards, dim=0) + ) + else: + raise RuntimeError( + f"Unexpected state_dict key type {type(state_dict[key])} found for {key}" + ) diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index 34e3cdcaf..f110e532b 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -11,6 +11,7 @@ from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar import torch + from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation from torch import fx, nn from torch.nn.modules.module import _addindent diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index bdfcf3c41..3d0ec7992 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -33,6 +33,7 @@ KJTListSplitsAwaitable, Multistreamable, ) +from torchrec.distributed.embedding_state import ShardedEmbeddingModuleState from torchrec.distributed.embedding_types import ( BaseEmbeddingSharder, EmbeddingComputeKernel, @@ -75,7 +76,7 @@ EmbeddingBagCollection, EmbeddingBagCollectionInterface, ) -from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule +from torchrec.optim.fused import FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor @@ -390,6 +391,7 @@ class ShardedEmbeddingBagCollection( KeyedTensor, EmbeddingBagCollectionContext, ], + ShardedEmbeddingModuleState, # TODO remove after compute_kernel X sharding decoupling FusedOptimizerModule, ): @@ -515,158 +517,13 @@ def __init__( ]: self.load_state_dict(module.state_dict(), strict=False) - @staticmethod - def _pre_load_state_dict_hook( - self: "ShardedEmbeddingBagCollection", - state_dict: Dict[str, Any], - prefix: str, - *args: Any, - ) -> None: - """ - Modify the destination state_dict for model parallel - to transform from ShardedTensors into tensors - """ - for ( - table_name, - model_shards, - ) in self._model_parallel_name_to_local_shards.items(): - key = f"{prefix}embedding_bags.{table_name}.weight" - - # If state_dict[key] is already a ShardedTensor, use its local shards - if isinstance(state_dict[key], ShardedTensor): - local_shards = state_dict[key].local_shards() - if len(local_shards) == 0: - state_dict[key] = torch.empty(0) - else: - dim = state_dict[key].metadata().shards_metadata[0].shard_sizes[1] - # CW multiple shards are merged - if len(local_shards) > 1: - state_dict[key] = torch.cat( - [s.tensor.view(-1) for s in local_shards], dim=0 - ).view(-1, dim) - else: - state_dict[key] = local_shards[0].tensor.view(-1, dim) - elif isinstance(state_dict[key], 
torch.Tensor): - local_shards = [] - for shard in model_shards: - # Extract shard size and offsets for splicing - shard_sizes = shard.metadata.shard_sizes - shard_offsets = shard.metadata.shard_offsets - - # Prepare tensor by splicing and placing on appropriate device - spliced_tensor = state_dict[key][ - shard_offsets[0] : shard_offsets[0] + shard_sizes[0], - shard_offsets[1] : shard_offsets[1] + shard_sizes[1], - ] - - # Append spliced tensor into local shards - local_shards.append(spliced_tensor) - state_dict[key] = ( - torch.empty(0) - if not local_shards - else torch.cat(local_shards, dim=0) - ) - else: - raise RuntimeError( - f"Unexpected state_dict key type {type(state_dict[key])} found for {key}" - ) - - def _initialize_torch_state(self) -> None: # noqa - """ - This provides consistency between this class and the EmbeddingBagCollection's - nn.Module API calls (state_dict, named_modules, etc) - """ - self.embedding_bags: nn.ModuleDict = nn.ModuleDict() - for table_name in self._table_names: - self.embedding_bags[table_name] = nn.Module() - self._model_parallel_name_to_local_shards = OrderedDict() - self._model_parallel_name_to_sharded_tensor = OrderedDict() - model_parallel_name_to_compute_kernel: Dict[str, str] = {} - for ( - table_name, - parameter_sharding, - ) in self.module_sharding_plan.items(): - if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: - continue - self._model_parallel_name_to_local_shards[table_name] = [] - model_parallel_name_to_compute_kernel[ - table_name - ] = parameter_sharding.compute_kernel - - self._name_to_table_size = {} - for table in self._embedding_bag_configs: - self._name_to_table_size[table.name] = ( - table.num_embeddings, - table.embedding_dim, - ) - - for sharding_type, lookup in zip( - self._sharding_type_to_sharding.keys(), self._lookups - ): - if sharding_type == ShardingType.DATA_PARALLEL.value: - # unwrap DDP - lookup = lookup.module - else: - # save local_shards for transforming MP params to shardedTensor - for key, v in lookup.state_dict().items(): - table_name = key[: -len(".weight")] - self._model_parallel_name_to_local_shards[table_name].extend( - v.local_shards() - ) - for ( - table_name, - tbe_slice, - ) in lookup.named_parameters_by_table(): - self.embedding_bags[table_name].register_parameter("weight", tbe_slice) - for ( - table_name, - local_shards, - ) in self._model_parallel_name_to_local_shards.items(): - # for shards that don't exist on this rank, register with empty tensor - if not hasattr(self.embedding_bags[table_name], "weight"): - self.embedding_bags[table_name].register_parameter( - "weight", nn.Parameter(torch.empty(0)) - ) - if ( - model_parallel_name_to_compute_kernel[table_name] - != EmbeddingComputeKernel.DENSE.value - ): - self.embedding_bags[table_name].weight._in_backward_optimizers = [ - EmptyFusedOptimizer() - ] - # created ShardedTensors once in init, use in post_state_dict_hook - self._model_parallel_name_to_sharded_tensor[ - table_name - ] = ShardedTensor._init_from_local_shards( - local_shards, - self._name_to_table_size[table_name], - process_group=self._env.process_group, - ) - - def post_state_dict_hook( - module: ShardedEmbeddingBagCollection, - destination: Dict[str, torch.Tensor], - prefix: str, - _local_metadata: Dict[str, Any], - ) -> None: - # Adjust dense MP - for ( - table_name, - sharded_t, - ) in module._model_parallel_name_to_sharded_tensor.items(): - destination_key = f"{prefix}embedding_bags.{table_name}.weight" - destination[destination_key] = sharded_t - - 
self._register_state_dict_hook(post_state_dict_hook) - self._register_load_state_dict_pre_hook( - self._pre_load_state_dict_hook, with_module=True - ) - self.reset_parameters() + @property + def module_weight_key(self) -> str: + return "embedding_bags" def reset_parameters(self) -> None: if self._device and self._device.type == "meta": return - # Initialize embedding bags weights with init_fn for table_config in self._embedding_bag_configs: assert table_config.init_fn is not None @@ -674,6 +531,26 @@ def reset_parameters(self) -> None: # pyre-ignore table_config.init_fn(param) + def _initialize_torch_state(self) -> None: # noqa + """ + Provides consistency between this class and the EmbeddingBagCollection's + nn.Module API calls (state_dict, named_modules, etc) + """ + + # Set module dict in child class so it gets registered in self._modules + self.embedding_bags = self.init_embedding_modules( + self.module_sharding_plan, + self._embedding_bag_configs, + self._sharding_type_to_sharding.keys(), + self._lookups, + self._env.process_group, + ) + self._register_state_dict_hook(self.post_state_dict_hook) + self._register_load_state_dict_pre_hook( + self.pre_load_state_dict_hook, with_module=True + ) + self.reset_parameters() + def _create_input_dist( self, input_feature_names: List[str], diff --git a/torchrec/distributed/tests/test_module_state.py b/torchrec/distributed/tests/test_module_state.py new file mode 100644 index 000000000..7eafc9db5 --- /dev/null +++ b/torchrec/distributed/tests/test_module_state.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Dict, List, Optional + +import hypothesis.strategies as st +import torch +import torch.nn as nn +from hypothesis import given, settings, Verbosity +from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, EmbeddingCollection +from torchrec.distributed.embedding import ShardedEmbeddingCollection +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection +from torchrec.distributed.sharding_plan import ( + column_wise, + construct_module_sharding_plan, + EmbeddingBagCollectionSharder, + EmbeddingCollectionSharder, + ParameterShardingGenerator, +) +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.types import ( + ModuleSharder, + ParameterSharding, + ShardedTensor, + ShardingEnv, +) +from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.test_utils import skip_if_asan_class + + +@skip_if_asan_class +class ModuleStateTest(MultiProcessTestBase): + @staticmethod + def _test_ebc( + tables: List[EmbeddingBagConfig], + rank: int, + world_size: int, + backend: str, + parameter_sharding_plan: Dict[str, ParameterSharding], + sharder: ModuleSharder[nn.Module], + local_size: Optional[int] = None, + ) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + model = EmbeddingBagCollection( + tables=tables, + device=ctx.device, + ) + sharded_model = sharder.shard( + module=model, + params=parameter_sharding_plan, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
+ env=ShardingEnv.from_process_group(ctx.pg), + device=ctx.device, + ) + assert isinstance(sharded_model, ShardedEmbeddingBagCollection) + + state_dict = sharded_model.state_dict() + + for state_dict_key in [ + "embedding_bags.0.weight", + "embedding_bags.1.weight", + ]: + assert ( + state_dict_key in state_dict + ), f"Expected '{state_dict_key}' in state_dict" + assert isinstance( + state_dict[state_dict_key], ShardedTensor + ), "expected state dict to contain ShardedTensor" + + # Check that embedding modules are registered as submodules + assert "embedding_bags" in sharded_model._modules + assert isinstance(sharded_model._modules["embedding_bags"], nn.ModuleDict) + + # try loading state dict + sharded_model.load_state_dict(state_dict) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + per_param_sharding=st.sampled_from( + [ + { + "0": column_wise(ranks=[0, 1]), + "1": column_wise(ranks=[1, 0]), + }, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_module_state_ebc( + self, + per_param_sharding: Dict[str, ParameterShardingGenerator], + ) -> None: + + WORLD_SIZE = 2 + EMBEDDING_DIM = 8 + NUM_EMBEDDINGS = 4 + + embedding_bag_configs = [ + EmbeddingBagConfig( + name=str(idx), + feature_names=[f"feature_{idx}"], + embedding_dim=EMBEDDING_DIM, + num_embeddings=NUM_EMBEDDINGS, + ) + for idx in per_param_sharding + ] + ebc = EmbeddingBagCollection(tables=embedding_bag_configs) + sharder = EmbeddingBagCollectionSharder() + + parameter_sharding_plan = construct_module_sharding_plan( + module=ebc, + per_param_sharding=per_param_sharding, + local_size=WORLD_SIZE, + world_size=WORLD_SIZE, + # pyre-ignore + sharder=sharder, + ) + + self._run_multi_process_test( + callable=self._test_ebc, + tables=embedding_bag_configs, + local_size=WORLD_SIZE, + world_size=WORLD_SIZE, + backend="nccl" + if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) + else "gloo", + sharder=sharder, + parameter_sharding_plan=parameter_sharding_plan, + ) + + @staticmethod + def _test_ec( + tables: List[EmbeddingConfig], + rank: int, + world_size: int, + backend: str, + parameter_sharding_plan: Dict[str, ParameterSharding], + sharder: ModuleSharder[nn.Module], + local_size: Optional[int] = None, + ) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + model = EmbeddingCollection( + tables=tables, + device=ctx.device, + ) + sharded_model = sharder.shard( + module=model, + params=parameter_sharding_plan, + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. 
+ env=ShardingEnv.from_process_group(ctx.pg), + device=ctx.device, + ) + assert isinstance(sharded_model, ShardedEmbeddingCollection) + + state_dict = sharded_model.state_dict() + + for state_dict_key in [ + "embeddings.0.weight", + "embeddings.1.weight", + ]: + assert ( + state_dict_key in state_dict + ), f"Expected '{state_dict_key}' in state_dict" + assert isinstance( + state_dict[state_dict_key], ShardedTensor + ), "expected state dict to contain ShardedTensor" + + # Check that embedding modules are registered as submodules + assert "embeddings" in sharded_model._modules + assert isinstance(sharded_model._modules["embeddings"], nn.ModuleDict) + + # try loading state dict + sharded_model.load_state_dict(state_dict) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + per_param_sharding=st.sampled_from( + [ + { + "0": column_wise(ranks=[0, 1]), + "1": column_wise(ranks=[1, 0]), + }, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_module_state_ec( + self, + per_param_sharding: Dict[str, ParameterShardingGenerator], + ) -> None: + + WORLD_SIZE = 2 + EMBEDDING_DIM = 8 + NUM_EMBEDDINGS = 4 + + embedding_configs = [ + EmbeddingConfig( + name=str(idx), + feature_names=[f"feature_{idx}"], + embedding_dim=EMBEDDING_DIM, + num_embeddings=NUM_EMBEDDINGS, + ) + for idx in per_param_sharding + ] + ebc = EmbeddingCollection(tables=embedding_configs) + sharder = EmbeddingCollectionSharder() + + parameter_sharding_plan = construct_module_sharding_plan( + module=ebc, + per_param_sharding=per_param_sharding, + local_size=WORLD_SIZE, + world_size=WORLD_SIZE, + # pyre-ignore + sharder=sharder, + ) + + self._run_multi_process_test( + callable=self._test_ec, + tables=embedding_configs, + local_size=WORLD_SIZE, + world_size=WORLD_SIZE, + backend="nccl" + if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) + else "gloo", + sharder=sharder, + parameter_sharding_plan=parameter_sharding_plan, + )
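
Note (illustrative, not part of the patch): the abstract base class introduced above is ShardedEmbeddingModuleState in torchrec/distributed/embedding_state.py. Its shared post_state_dict_hook / pre_load_state_dict_hook logic builds state_dict keys through construct_state_dict_key(), and the only per-module difference between EC and EBC is the prefix returned by the module_weight_key property ("embeddings" vs "embedding_bags"). The self-contained sketch below mirrors that pattern with hypothetical toy classes (the _SharedStateMixin/_ToyShardedEC/_ToyShardedEBC names are illustrative stand-ins, not TorchRec code):

    # Minimal sketch of the refactor's core idea: identical key-construction
    # logic lives in one mixin; subclasses only supply the weight-key prefix.
    import abc


    class _SharedStateMixin(abc.ABC):
        @property
        @abc.abstractmethod
        def module_weight_key(self) -> str:
            ...

        def construct_state_dict_key(self, prefix: str, table_name: str) -> str:
            # Same formula the real mixin uses in both state_dict hooks.
            return f"{prefix}{self.module_weight_key}.{table_name}.weight"


    class _ToyShardedEC(_SharedStateMixin):
        @property
        def module_weight_key(self) -> str:
            return "embeddings"


    class _ToyShardedEBC(_SharedStateMixin):
        @property
        def module_weight_key(self) -> str:
            return "embedding_bags"


    assert _ToyShardedEC().construct_state_dict_key("", "t1") == "embeddings.t1.weight"
    assert (
        _ToyShardedEBC().construct_state_dict_key("model.", "t1")
        == "model.embedding_bags.t1.weight"
    )

Keeping that prefix behind a single abstract property is what allows both ShardedEmbeddingCollection and ShardedEmbeddingBagCollection to reuse the same init_embedding_modules(), post_state_dict_hook, and pre_load_state_dict_hook implementations instead of maintaining two copies.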