diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 8d17c36ce..e7ffc7644 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -373,6 +373,9 @@ def __init__( self._table_names: List[str] = [ config.name for config in self._embedding_configs ] + self._table_name_to_config: Dict[str, EmbeddingConfig] = { + config.name: config for config in self._embedding_configs + } self.module_sharding_plan: EmbeddingModuleShardingPlan = cast( EmbeddingModuleShardingPlan, { diff --git a/torchrec/distributed/mc_embedding.py b/torchrec/distributed/mc_embedding.py new file mode 100644 index 000000000..3dd7953b9 --- /dev/null +++ b/torchrec/distributed/mc_embedding.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 + +from dataclasses import dataclass +from typing import Dict, Optional, Type, Union + +import torch + +from torchrec.distributed.embedding import ( + EmbeddingCollectionContext, + EmbeddingCollectionSharder, + ShardedEmbeddingCollection, +) + +from torchrec.distributed.embedding_types import KJTList +from torchrec.distributed.mc_embedding_modules import ( + BaseManagedCollisionEmbeddingCollectionSharder, + BaseShardedManagedCollisionEmbeddingCollection, +) +from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder +from torchrec.distributed.types import ( + ParameterSharding, + QuantizedCommCodecs, + ShardingEnv, +) +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection + + +@dataclass +class ManagedCollisionEmbeddingCollectionContext(EmbeddingCollectionContext): + evictions_per_table: Optional[Dict[str, Optional[torch.Tensor]]] = None + remapped_kjt: Optional[KJTList] = None + + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + super().record_stream(stream) + if self.evictions_per_table: + # pyre-ignore + for value in self.evictions_per_table.values(): + if value is None: + continue + value.record_stream(stream) + if self.remapped_kjt is not None: + self.remapped_kjt.record_stream(stream) + + +class ShardedManagedCollisionEmbeddingCollection( + BaseShardedManagedCollisionEmbeddingCollection[ + ManagedCollisionEmbeddingCollectionContext + ] +): + def __init__( + self, + module: ManagedCollisionEmbeddingCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + ec_sharder: EmbeddingCollectionSharder, + mc_sharder: ManagedCollisionCollectionSharder, + # TODO - maybe we need this to manage unsharded/sharded consistency/state consistency + env: ShardingEnv, + device: torch.device, + ) -> None: + super().__init__( + module, + table_name_to_parameter_sharding, + ec_sharder, + mc_sharder, + env, + device, + ) + + # For consistency with embeddingbag + # pyre-ignore [8] + self._embedding_collection: ShardedEmbeddingCollection = self._embedding_module + + def create_context( + self, + ) -> ManagedCollisionEmbeddingCollectionContext: + return ManagedCollisionEmbeddingCollectionContext(sharding_contexts=[]) + + +class ManagedCollisionEmbeddingCollectionSharder( + BaseManagedCollisionEmbeddingCollectionSharder[ManagedCollisionEmbeddingCollection] +): + def __init__( + self, + ec_sharder: Optional[EmbeddingCollectionSharder] = None, + mc_sharder: Optional[ManagedCollisionCollectionSharder] = None, + 
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__( + ec_sharder + or EmbeddingCollectionSharder(qcomm_codecs_registry=qcomm_codecs_registry), + mc_sharder or ManagedCollisionCollectionSharder(), + qcomm_codecs_registry=qcomm_codecs_registry, + ) + + def shard( + self, + module: ManagedCollisionEmbeddingCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> ShardedManagedCollisionEmbeddingCollection: + + if device is None: + device = torch.device("cuda") + + return ShardedManagedCollisionEmbeddingCollection( + module, + params, + # pyre-ignore [6] + ec_sharder=self._e_sharder, + mc_sharder=self._mc_sharder, + env=env, + device=device, + ) + + @property + def module_type(self) -> Type[ManagedCollisionEmbeddingCollection]: + return ManagedCollisionEmbeddingCollection diff --git a/torchrec/distributed/mc_embedding_modules.py b/torchrec/distributed/mc_embedding_modules.py new file mode 100644 index 000000000..ff9f71e25 --- /dev/null +++ b/torchrec/distributed/mc_embedding_modules.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Dict, Iterator, List, Optional, Tuple, Type, TypeVar, Union + +import torch +from torch.autograd.profiler import record_function +from torchrec.distributed.embedding import ( + EmbeddingCollectionSharder, + ShardedEmbeddingCollection, +) + +from torchrec.distributed.embedding_types import ( + BaseEmbeddingSharder, + EmbeddingComputeKernel, + KJTList, + ShardedEmbeddingModule, +) +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionSharder, + ShardedEmbeddingBagCollection, +) +from torchrec.distributed.mc_modules import ( + ManagedCollisionCollectionSharder, + ShardedManagedCollisionCollection, +) +from torchrec.distributed.types import ( + Awaitable, + LazyAwaitable, + Multistreamable, + NoWait, + ParameterSharding, + QuantizedCommCodecs, + ShardingEnv, +) +from torchrec.distributed.utils import append_prefix +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.modules.mc_embedding_modules import ( + BaseManagedCollisionEmbeddingCollection, + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +logger: logging.Logger = logging.getLogger(__name__) + + +ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable) + + +class BaseShardedManagedCollisionEmbeddingCollection( + ShardedEmbeddingModule[ + KJTList, + List[torch.Tensor], + Tuple[LazyAwaitable[KeyedTensor], LazyAwaitable[Optional[KeyedJaggedTensor]]], + ShrdCtx, + ] +): + def __init__( + self, + module: Union[ + ManagedCollisionEmbeddingBagCollection, ManagedCollisionEmbeddingCollection + ], + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + e_sharder: Union[EmbeddingBagCollectionSharder, EmbeddingCollectionSharder], + mc_sharder: ManagedCollisionCollectionSharder, + # TODO - maybe we need this to manage unsharded/sharded consistency/state consistency + env: ShardingEnv, + device: torch.device, + ) -> None: + super().__init__() + + self._device = device + self._env = env + + if isinstance(module, ManagedCollisionEmbeddingBagCollection): + assert 
isinstance(e_sharder, EmbeddingBagCollectionSharder) + assert isinstance(module._embedding_module, EmbeddingBagCollection) + self.bagged: bool = True + + self._embedding_module: ShardedEmbeddingBagCollection = e_sharder.shard( + module._embedding_module, + table_name_to_parameter_sharding, + env=env, + device=device, + ) + else: + assert isinstance(e_sharder, EmbeddingCollectionSharder) + assert isinstance(module._embedding_module, EmbeddingCollection) + self.bagged: bool = False + + self._embedding_module: ShardedEmbeddingCollection = e_sharder.shard( + module._embedding_module, + table_name_to_parameter_sharding, + env=env, + device=device, + ) + # TODO: This is a hack since _embedding_module doesn't need input + # dist, so eliminating it so all fused a2a will ignore it. + self._embedding_module._has_uninitialized_input_dist = False + self._managed_collision_collection: ShardedManagedCollisionCollection = mc_sharder.shard( + module._managed_collision_collection, + table_name_to_parameter_sharding, + env=env, + device=device, + # pyre-ignore + sharding_type_to_sharding=self._embedding_module._sharding_type_to_sharding, + ) + self._return_remapped_features: bool = module._return_remapped_features + + # pyre-ignore + self._table_to_tbe_and_index = {} + for lookup in self._embedding_module._lookups: + for emb_module in lookup._emb_modules: + for table_idx, table in enumerate(emb_module._config.embedding_tables): + self._table_to_tbe_and_index[table.name] = ( + emb_module._emb_module, + torch.tensor([table_idx], dtype=torch.int, device=self._device), + ) + self._buffer_ids: torch.Tensor = torch.tensor( + [0], device=self._device, dtype=torch.int + ) + + # pyre-ignore + def input_dist( + self, + ctx: ShrdCtx, + features: KeyedJaggedTensor, + ) -> Awaitable[Awaitable[KJTList]]: + # TODO: resolve incompatiblity with different contexts + return self._managed_collision_collection.input_dist( + # pyre-fixme [6] + ctx, + features, + ) + + def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None: + for table, evictions_indices_for_table in evictions_per_table.items(): + if evictions_indices_for_table is not None: + (tbe, logical_table_ids) = self._table_to_tbe_and_index[table] + pruned_indices_offsets = torch.tensor( + [0, evictions_indices_for_table.shape[0]], + dtype=torch.long, + device=self._device, + ) + logger.info( + f"Evicting {evictions_indices_for_table.numel()} ids from {table}" + ) + with torch.no_grad(): + # embeddings, and optimizer state will be reset + tbe.reset_embedding_weight_momentum( + pruned_indices=evictions_indices_for_table.long(), + pruned_indices_offsets=pruned_indices_offsets, + logical_table_ids=logical_table_ids, + buffer_ids=self._buffer_ids, + ) + + if self.bagged: + table_weight_param = ( + self._embedding_module.embedding_bags.get_parameter( + f"{table}.weight" + ) + ) + else: + table_weight_param = ( + self._embedding_module.embeddings.get_parameter( + f"{table}.weight" + ) + ) + + init_fn = self._embedding_module._table_name_to_config[ + table + ].init_fn + + # Set evicted indices to original init_fn instead of all zeros + # pyre-ignore [29] + table_weight_param[evictions_indices_for_table] = init_fn( + table_weight_param[evictions_indices_for_table] + ) + + def compute( + self, + ctx: ShrdCtx, + dist_input: KJTList, + ) -> List[torch.Tensor]: + with record_function("## compute:mcc ##"): + remapped_kjt = self._managed_collision_collection.compute( + # pyre-fixme [6] + ctx, + dist_input, + ) + evictions_per_table = 
self._managed_collision_collection.evict() + + self._evict(evictions_per_table) + ctx.remapped_kjt = remapped_kjt + ctx.evictions_per_table = evictions_per_table + + # pyre-ignore + return self._embedding_module.compute(ctx, remapped_kjt) + + # pyre-ignore + def output_dist( + self, + ctx: ShrdCtx, + output: List[torch.Tensor], + ) -> Tuple[LazyAwaitable[KeyedTensor], LazyAwaitable[Optional[KeyedJaggedTensor]]]: + + # pyre-ignore [6] + ebc_awaitable = self._embedding_module.output_dist(ctx, output) + + if self._return_remapped_features: + kjt_awaitable = self._managed_collision_collection.output_dist( + # pyre-fixme [6] + ctx, + # pyre-ignore [16] + ctx.remapped_kjt, + ) + else: + kjt_awaitable = NoWait(None) + + # pyre-ignore + return ebc_awaitable, kjt_awaitable + + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: + for fqn, _ in self.named_parameters(): + yield append_prefix(prefix, fqn) + for fqn, _ in self.named_buffers(): + yield append_prefix(prefix, fqn) + + +M = TypeVar("M", bound=BaseManagedCollisionEmbeddingCollection) + + +class BaseManagedCollisionEmbeddingCollectionSharder(BaseEmbeddingSharder[M]): + def __init__( + self, + e_sharder: Union[EmbeddingBagCollectionSharder, EmbeddingCollectionSharder], + mc_sharder: ManagedCollisionCollectionSharder, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._e_sharder: Union[ + EmbeddingBagCollectionSharder, EmbeddingCollectionSharder + ] = e_sharder + self._mc_sharder: ManagedCollisionCollectionSharder = mc_sharder + + def shardable_parameters( + self, module: BaseManagedCollisionEmbeddingCollection + ) -> Dict[str, torch.nn.Parameter]: + # pyre-ignore + return self._e_sharder.shardable_parameters(module._embedding_module) + + def compute_kernels( + self, + sharding_type: str, + compute_device_type: str, + ) -> List[str]: + return [EmbeddingComputeKernel.FUSED.value] + + def sharding_types(self, compute_device_type: str) -> List[str]: + return list( + set.intersection( + set(self._e_sharder.sharding_types(compute_device_type)), + set(self._mc_sharder.sharding_types(compute_device_type)), + ) + ) diff --git a/torchrec/distributed/mc_embeddingbag.py b/torchrec/distributed/mc_embeddingbag.py index b86bade30..0f98b1993 100644 --- a/torchrec/distributed/mc_embeddingbag.py +++ b/torchrec/distributed/mc_embeddingbag.py @@ -1,46 +1,32 @@ -#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import logging +#!/usr/bin/env python3 + from dataclasses import dataclass -from typing import Dict, Iterator, List, Optional, Tuple, Type +from typing import Dict, Optional, Type import torch -from torch.autograd.profiler import record_function - -from torchrec.distributed.embedding_types import ( - BaseEmbeddingSharder, - EmbeddingComputeKernel, - KJTList, - ShardedEmbeddingModule, -) +from torchrec.distributed.embedding_types import KJTList from torchrec.distributed.embeddingbag import ( EmbeddingBagCollectionContext, EmbeddingBagCollectionSharder, ShardedEmbeddingBagCollection, ) -from torchrec.distributed.mc_modules import ( - ManagedCollisionCollectionSharder, - ShardedManagedCollisionCollection, +from torchrec.distributed.mc_embedding_modules import ( + BaseManagedCollisionEmbeddingCollectionSharder, + BaseShardedManagedCollisionEmbeddingCollection, ) +from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder from torchrec.distributed.types import ( - Awaitable, - LazyAwaitable, - NoWait, ParameterSharding, QuantizedCommCodecs, ShardingEnv, ) -from torchrec.distributed.utils import append_prefix from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor - - -logger: logging.Logger = logging.getLogger(__name__) @dataclass @@ -61,11 +47,8 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None: class ShardedManagedCollisionEmbeddingBagCollection( - ShardedEmbeddingModule[ - KJTList, - List[torch.Tensor], - Tuple[LazyAwaitable[KeyedTensor], LazyAwaitable[Optional[KeyedJaggedTensor]]], - ManagedCollisionEmbeddingBagCollectionContext, + BaseShardedManagedCollisionEmbeddingCollection[ + ManagedCollisionEmbeddingBagCollectionContext ] ): def __init__( @@ -78,146 +61,31 @@ def __init__( env: ShardingEnv, device: torch.device, ) -> None: - super().__init__() - - self._device = device - self._env = env - - self._embedding_bag_collection: ShardedEmbeddingBagCollection = ( - ebc_sharder.shard( - module._embedding_bag_collection, - table_name_to_parameter_sharding, - env=env, - device=device, - ) - ) - # TODO: This is a hack since _embedding_bag_collection doesn't need input - # dist, so eliminating it so all fused a2a will ignore it. 
- self._embedding_bag_collection._has_uninitialized_input_dist = False - self._managed_collision_collection: ShardedManagedCollisionCollection = mc_sharder.shard( - module._managed_collision_collection, + super().__init__( + module, table_name_to_parameter_sharding, - env=env, - device=device, - sharding_type_to_sharding=self._embedding_bag_collection._sharding_type_to_sharding, + ebc_sharder, + mc_sharder, + env, + device, ) - self._return_remapped_features: bool = module._return_remapped_features - # pyre-ignore - self._table_to_tbe_and_index = {} - for lookup in self._embedding_bag_collection._lookups: - for emb_module in lookup._emb_modules: - for table_idx, table in enumerate(emb_module._config.embedding_tables): - self._table_to_tbe_and_index[table.name] = ( - emb_module._emb_module, - torch.tensor([table_idx], dtype=torch.int, device=self._device), - ) - self._buffer_ids: torch.Tensor = torch.tensor( - [0], device=self._device, dtype=torch.int - ) - - # pyre-ignore - def input_dist( - self, - ctx: ManagedCollisionEmbeddingBagCollectionContext, - features: KeyedJaggedTensor, - ) -> Awaitable[Awaitable[KJTList]]: - # TODO: resolve incompatiblity with different contexts - return self._managed_collision_collection.input_dist( - # pyre-fixme [6] - ctx, - features, + # For backwards compat, some references still to self._embedding_bag_collection + # pyre-ignore [8] + self._embedding_bag_collection: ShardedEmbeddingBagCollection = ( + self._embedding_module ) - def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None: - for table, evictions_indices_for_table in evictions_per_table.items(): - if evictions_indices_for_table is not None: - (tbe, logical_table_ids) = self._table_to_tbe_and_index[table] - pruned_indices_offsets = torch.tensor( - [0, evictions_indices_for_table.shape[0]], - dtype=torch.long, - device=self._device, - ) - logger.info( - f"Evicting {evictions_indices_for_table.numel()} ids from {table}" - ) - with torch.no_grad(): - # embeddings, and optimizer state will be reset - tbe.reset_embedding_weight_momentum( - pruned_indices=evictions_indices_for_table.long(), - pruned_indices_offsets=pruned_indices_offsets, - logical_table_ids=logical_table_ids, - buffer_ids=self._buffer_ids, - ) - table_weight_param = ( - self._embedding_bag_collection.embedding_bags.get_parameter( - f"{table}.weight" - ) - ) - - init_fn = self._embedding_bag_collection._table_name_to_config[ - table - ].init_fn - - # pyre-ignore - # Set evicted indices to original init_fn instead of all zeros - table_weight_param[evictions_indices_for_table] = init_fn( - table_weight_param[evictions_indices_for_table] - ) - - def compute( + def create_context( self, - ctx: ManagedCollisionEmbeddingBagCollectionContext, - dist_input: KJTList, - ) -> List[torch.Tensor]: - with record_function("## compute:mcc ##"): - remapped_kjt = self._managed_collision_collection.compute( - # pyre-fixme [6] - ctx, - dist_input, - ) - ctx.remapped_kjt = remapped_kjt - if self.training: - evictions_per_table = self._managed_collision_collection.evict() - self._evict(evictions_per_table) - ctx.evictions_per_table = evictions_per_table - - return self._embedding_bag_collection.compute(ctx, remapped_kjt) - - # pyre-ignore - def output_dist( - self, - ctx: ManagedCollisionEmbeddingBagCollectionContext, - output: List[torch.Tensor], - ) -> Tuple[LazyAwaitable[KeyedTensor], LazyAwaitable[Optional[KeyedJaggedTensor]]]: - - ebc_awaitable = self._embedding_bag_collection.output_dist(ctx, output) - - if 
self._return_remapped_features: - kjt_awaitable = self._managed_collision_collection.output_dist( - # pyre-fixme [6] - ctx, - # pyre-ignore [6] - ctx.remapped_kjt, - ) - else: - kjt_awaitable = NoWait(None) - - # pyre-ignore - return ebc_awaitable, kjt_awaitable - - def create_context(self) -> ManagedCollisionEmbeddingBagCollectionContext: + ) -> ManagedCollisionEmbeddingBagCollectionContext: return ManagedCollisionEmbeddingBagCollectionContext(sharding_contexts=[]) - def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: - for fqn, _ in self.named_parameters(): - yield append_prefix(prefix, fqn) - for fqn, _ in self.named_buffers(): - yield append_prefix(prefix, fqn) - class ManagedCollisionEmbeddingBagCollectionSharder( - BaseEmbeddingSharder[ManagedCollisionEmbeddingBagCollection] + BaseManagedCollisionEmbeddingCollectionSharder[ + ManagedCollisionEmbeddingBagCollection + ] ): def __init__( self, @@ -225,12 +93,13 @@ def __init__( mc_sharder: Optional[ManagedCollisionCollectionSharder] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: - super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) - self._ebc_sharder: EmbeddingBagCollectionSharder = ( - ebc_sharder or EmbeddingBagCollectionSharder(self.qcomm_codecs_registry) - ) - self._mc_sharder: ManagedCollisionCollectionSharder = ( - mc_sharder or ManagedCollisionCollectionSharder() + super().__init__( + ebc_sharder + or EmbeddingBagCollectionSharder( + qcomm_codecs_registry=qcomm_codecs_registry + ), + mc_sharder or ManagedCollisionCollectionSharder(), + qcomm_codecs_registry=qcomm_codecs_registry, ) def shard( @@ -247,32 +116,13 @@ def shard( return ShardedManagedCollisionEmbeddingBagCollection( module, params, - ebc_sharder=self._ebc_sharder, + # pyre-ignore [6] + ebc_sharder=self._e_sharder, mc_sharder=self._mc_sharder, env=env, device=device, ) - def shardable_parameters( - self, module: ManagedCollisionEmbeddingBagCollection - ) -> Dict[str, torch.nn.Parameter]: - return self._ebc_sharder.shardable_parameters(module._embedding_bag_collection) - @property def module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]: return ManagedCollisionEmbeddingBagCollection - - def compute_kernels( - self, - sharding_type: str, - compute_device_type: str, - ) -> List[str]: - return [EmbeddingComputeKernel.FUSED.value] - - def sharding_types(self, compute_device_type: str) -> List[str]: - return list( - set.intersection( - set(self._ebc_sharder.sharding_types(compute_device_type)), - set(self._mc_sharder.sharding_types(compute_device_type)), - ) - ) diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index 7c4a96ecd..d0f96fbed 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -21,6 +21,7 @@ FeatureProcessedEmbeddingBagCollectionSharder, ) from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder +from torchrec.distributed.mc_embedding import ManagedCollisionEmbeddingCollectionSharder from torchrec.distributed.mc_embeddingbag import ( ManagedCollisionEmbeddingBagCollectionSharder, ) @@ -47,6 +48,7 @@ def get_default_sharders() -> List[ModuleSharder[nn.Module]]: cast(ModuleSharder[nn.Module], QuantEmbeddingBagCollectionSharder()), cast(ModuleSharder[nn.Module], QuantEmbeddingCollectionSharder()), cast(ModuleSharder[nn.Module], ManagedCollisionEmbeddingBagCollectionSharder()), + cast(ModuleSharder[nn.Module], ManagedCollisionEmbeddingCollectionSharder()), ] diff 
--git a/torchrec/distributed/tests/test_mc_embedding.py b/torchrec/distributed/tests/test_mc_embedding.py new file mode 100644 index 000000000..98732a284 --- /dev/null +++ b/torchrec/distributed/tests/test_mc_embedding.py @@ -0,0 +1,557 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from torchrec.distributed.embedding import ShardedEmbeddingCollection +from torchrec.distributed.mc_embedding import ( + ManagedCollisionEmbeddingCollectionSharder, + ShardedManagedCollisionEmbeddingCollection, +) +from torchrec.distributed.mc_modules import ShardedManagedCollisionCollection +from torchrec.distributed.shard import _shard_modules + +from torchrec.distributed.sharding_plan import construct_module_sharding_plan, row_wise + +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan +from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.modules.embedding_modules import EmbeddingCollection +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection +from torchrec.modules.mc_modules import ( + DistanceLFU_EvictionPolicy, + ManagedCollisionCollection, + MCHManagedCollisionModule, +) +from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward +from torchrec.optim.rowwise_adagrad import RowWiseAdagrad +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.test_utils import skip_if_asan_class + + +class SparseArch(nn.Module): + def __init__( + self, + tables: List[EmbeddingConfig], + device: torch.device, + return_remapped: bool = False, + mch_size: Optional[int] = None, + ) -> None: + super().__init__() + self._return_remapped = return_remapped + + def mch_hash_func(id: torch.Tensor, hash_size: int) -> torch.Tensor: + return id % hash_size + + mc_modules = {} + mc_modules["table_0"] = MCHManagedCollisionModule( + zch_size=tables[0].num_embeddings - mch_size + if mch_size + else tables[0].num_embeddings, + mch_size=mch_size, + mch_hash_func=mch_hash_func if mch_size else None, + input_hash_size=4000, + device=device, + eviction_interval=2, + eviction_policy=DistanceLFU_EvictionPolicy(), + ) + + mc_modules["table_1"] = MCHManagedCollisionModule( + zch_size=tables[1].num_embeddings - mch_size + if mch_size + else tables[1].num_embeddings, + mch_size=mch_size, + mch_hash_func=mch_hash_func if mch_size else None, + device=device, + input_hash_size=4000, + eviction_interval=2, + eviction_policy=DistanceLFU_EvictionPolicy(), + ) + + self._mc_ec: ManagedCollisionEmbeddingCollection = ManagedCollisionEmbeddingCollection( + EmbeddingCollection( + tables=tables, + device=device, + ), + ManagedCollisionCollection( + managed_collision_modules=mc_modules, + # pyre-ignore + embedding_configs=tables, + ), + return_remapped_features=self._return_remapped, + ) + + def forward( + self, kjt: KeyedJaggedTensor + ) -> Tuple[torch.Tensor, Optional[Dict[str, JaggedTensor]]]: + if self._return_remapped: + ec_out, remapped_ids_out = self._mc_ec(kjt) + else: + ec_out = self._mc_ec(kjt) + remapped_ids_out = None + + pred = torch.cat( + [ec_out[key].values() for key in ["feature_0", 
"feature_1"]], + dim=1, + ) + loss = pred.mean() + return loss, remapped_ids_out + + +def _test_sharding( # noqa C901 + tables: List[EmbeddingConfig], + rank: int, + world_size: int, + sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, + mch_size: Optional[int] = None, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + return_remapped: bool = True + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + mch_size=mch_size, + ) + + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight, + sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ec, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ec": module_sharding_plan}), + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + + assert isinstance( + sharded_sparse_arch._mc_ec, ShardedManagedCollisionEmbeddingCollection + ) + assert isinstance( + sharded_sparse_arch._mc_ec._managed_collision_collection, + ShardedManagedCollisionCollection, + ) + + +def _test_sharding_and_remapping( # noqa C901 + tables: List[EmbeddingConfig], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]], + sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, + mch_size: Optional[int] = None, +) -> None: + + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + kjt_out_per_iter = [ + kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank + ] + return_remapped: bool = True + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + mch_size=mch_size, + ) + + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight, + sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ec, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ec": module_sharding_plan}), + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + + assert isinstance( + sharded_sparse_arch._mc_ec, ShardedManagedCollisionEmbeddingCollection + ) + assert isinstance( + sharded_sparse_arch._mc_ec._embedding_collection, + ShardedEmbeddingCollection, + ) + assert ( + sharded_sparse_arch._mc_ec._embedding_collection._has_uninitialized_input_dist + is False + ) + assert ( + not hasattr( + sharded_sparse_arch._mc_ec._embedding_collection, "_input_dists" + ) + or len(sharded_sparse_arch._mc_ec._embedding_collection._input_dists) == 0 + ) + + assert isinstance( + sharded_sparse_arch._mc_ec._managed_collision_collection, + 
ShardedManagedCollisionCollection, + ) + + test_state_dict = sharded_sparse_arch.state_dict() + sharded_sparse_arch.load_state_dict(test_state_dict) + + # sharded model + # each rank gets a subbatch + loss1, remapped_ids1 = sharded_sparse_arch(kjt_input) + loss1.backward() + loss2, remapped_ids2 = sharded_sparse_arch(kjt_input) + loss2.backward() + remapped_ids = [remapped_ids1, remapped_ids2] + for key in kjt_input.keys(): + for i, kjt_out in enumerate(kjt_out_per_iter): + assert torch.equal( + remapped_ids[i][key].values(), + kjt_out[key].values(), + ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}" + + # TODO: validate embedding rows, and eviction + + +@skip_if_asan_class +class ShardedMCEmbeddingCollectionParallelTest(MultiProcessTestBase): + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_uneven_sharding(self) -> None: + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=17, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=33, + ), + ] + + self._run_multi_process_test( + callable=_test_sharding, + world_size=WORLD_SIZE, + tables=embedding_config, + sharder=ManagedCollisionEmbeddingCollectionSharder(), + backend="nccl", + ) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_even_sharding(self) -> None: + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), + ] + + self._run_multi_process_test( + callable=_test_sharding, + world_size=WORLD_SIZE, + tables=embedding_config, + sharder=ManagedCollisionEmbeddingCollectionSharder(), + backend="nccl", + ) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_sharding_zch_mc_ec(self) -> None: + + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), + ] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [1000, 2000, 1001, 2000, 2001, 2002], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [ + 1000, + 1002, + 1004, + 2000, + 2002, + 2004, + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = [] + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 15, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 7, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + 
weights=None, + ), + ] + ) + # TODO: cleanup sorting so more dedugable/logical initial fill + + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 14, 4, 27, 29, 28], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 5, 6, 27, 28, 30], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + + self._run_multi_process_test( + callable=_test_sharding_and_remapping, + world_size=WORLD_SIZE, + tables=embedding_config, + kjt_input_per_rank=kjt_input_per_rank, + kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, + sharder=ManagedCollisionEmbeddingCollectionSharder(), + backend="nccl", + ) + + # pyre-ignore + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_sharding_zch_mch_mc_ec(self) -> None: + + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), + ] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [1000, 2000, 1001, 2000, 2001, 2002], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [ + 1000, + 1002, + 1004, + 2000, + 2002, + 2004, + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = [] + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [ + 4, # 1000 % 4 + 4 + 12, # 2000 % 4 + 12 + 5, # 1001 % 4 + 4 + 28, # 2000 % 4 + 28 + 29, # 2001 % 4 + 28 + 30, # 2002 % 4 + 28 + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [ + 4, # 1000 % 4 + 4 + 6, # 1002 % 4 + 4 + 4, # 1004 % 4 + 4 + 28, # 2000 % 4 + 28 + 30, # 2002 % 4 + 28 + 28, # 2004 % 4 + 28 + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + # TODO: cleanup sorting so more dedugable/logical initial fill + + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [ + 0, # zch for 1000 + 10, # zch for 2000 + 1, # zch for 1001 + 23, # zch for 2000 + 25, # zch for 2001 + 24, # zch for 2002 + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [ + 0, # zch for 1000 + 2, # zch for 1002 + 4, # 1004 % 4 + 4 + 23, # zch for 2000 + 24, # zch for 2002 + 26, # zch for 2004 + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + + self._run_multi_process_test( + callable=_test_sharding_and_remapping, + world_size=WORLD_SIZE, + tables=embedding_config, + mch_size=8, + kjt_input_per_rank=kjt_input_per_rank, + kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, + sharder=ManagedCollisionEmbeddingCollectionSharder(), + 
backend="nccl", + ) diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py index 2ac19715e..a075b0b41 100644 --- a/torchrec/distributed/tests/test_sharding_plan.py +++ b/torchrec/distributed/tests/test_sharding_plan.py @@ -23,6 +23,7 @@ FusedEmbeddingBagCollectionSharder, get_module_to_default_sharders, ManagedCollisionEmbeddingBagCollectionSharder, + ManagedCollisionEmbeddingCollectionSharder, ParameterShardingGenerator, QuantEmbeddingBagCollectionSharder, QuantEmbeddingCollectionSharder, @@ -52,7 +53,10 @@ ) from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection -from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection +from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) from torchrec.quant.embedding_modules import ( EmbeddingBagCollection as QuantEmbeddingBagCollection, EmbeddingCollection as QuantEmbeddingCollection, @@ -710,6 +714,7 @@ def test_module_to_default_sharders(self) -> None: QuantEmbeddingBagCollection, QuantEmbeddingCollection, ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, ], ) self.assertIsInstance( @@ -738,3 +743,8 @@ def test_module_to_default_sharders(self) -> None: default_sharder_map[ManagedCollisionEmbeddingBagCollection], ManagedCollisionEmbeddingBagCollectionSharder, ) + + self.assertIsInstance( + default_sharder_map[ManagedCollisionEmbeddingCollection], + ManagedCollisionEmbeddingCollectionSharder, + ) diff --git a/torchrec/modules/mc_embedding_modules.py b/torchrec/modules/mc_embedding_modules.py index 2ea4b9060..8b66e1e45 100644 --- a/torchrec/modules/mc_embedding_modules.py +++ b/torchrec/modules/mc_embedding_modules.py @@ -6,32 +6,34 @@ # LICENSE file in the root directory of this source tree. -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn -from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) from torchrec.modules.mc_modules import ManagedCollisionCollection -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor def evict( - evictions: Dict[str, Optional[torch.Tensor]], ebc: EmbeddingBagCollection + evictions: Dict[str, Optional[torch.Tensor]], + ebc: Union[EmbeddingBagCollection, EmbeddingCollection], ) -> None: # TODO: write function return -class ManagedCollisionEmbeddingBagCollection(nn.Module): +class BaseManagedCollisionEmbeddingCollection(nn.Module): """ - ManagedCollisionEmbeddingBagCollection represents a EmbeddingBagCollection module and a set of managed collision modules. - The inputs into the MC-EBC will first be modified by the managed collision module before being passed into the embedding bag collection. - - For details of input and output types, see EmbeddingBagCollection + BaseManagedCollisionEmbeddingCollection represents a EC/EBC module and a set of managed collision modules. + The inputs into the MC-EC/EBC will first be modified by the managed collision module before being passed into the embedding collection. 
Args: - embedding_bag_collection: EmbeddingBagCollection to lookup embeddings + embedding_module: EmbeddingCollection to lookup embeddings managed_collision_modules: Dict of managed collision modules return_remapped_features (bool): whether to return remapped input features in addition to embeddings @@ -40,33 +42,104 @@ class ManagedCollisionEmbeddingBagCollection(nn.Module): def __init__( self, - embedding_bag_collection: EmbeddingBagCollection, + embedding_module: Union[EmbeddingBagCollection, EmbeddingCollection], managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False, ) -> None: super().__init__() - self._embedding_bag_collection = embedding_bag_collection self._managed_collision_collection = managed_collision_collection self._return_remapped_features = return_remapped_features - - assert ( - self._embedding_bag_collection.embedding_bag_configs() - == self._managed_collision_collection.embedding_configs() - ), "Embedding Collection and Managed Collision Collection must contain the Embedding Configs" + self._embedding_module: Union[ + EmbeddingBagCollection, EmbeddingCollection + ] = embedding_module + + if isinstance(embedding_module, EmbeddingBagCollection): + assert ( + self._embedding_module.embedding_bag_configs() + == self._managed_collision_collection.embedding_configs() + ), "Embedding Bag Collection and Managed Collision Collection must contain the Embedding Configs" + + else: + assert ( + self._embedding_module.embedding_configs() + == self._managed_collision_collection.embedding_configs() + ), "Embedding Collection and Managed Collision Collection must contain the Embedding Configs" def forward( self, features: KeyedJaggedTensor, - ) -> Tuple[KeyedTensor, Optional[KeyedJaggedTensor]]: + ) -> Tuple[ + Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor] + ]: features = self._managed_collision_collection(features) - pooled_embeddings = self._embedding_bag_collection(features) + embedding_res = self._embedding_module(features) - evict( - self._managed_collision_collection.evict(), self._embedding_bag_collection - ) + evict(self._managed_collision_collection.evict(), self._embedding_module) if not self._return_remapped_features: - return pooled_embeddings, None - return pooled_embeddings, features + return embedding_res, None + return embedding_res, features + + +class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollection): + """ + ManagedCollisionEmbeddingCollection represents a EmbeddingCollection module and a set of managed collision modules. + The inputs into the MC-EC will first be modified by the managed collision module before being passed into the embedding collection. 
+ + For details of input and output types, see EmbeddingCollection + + Args: + embedding_module: EmbeddingCollection to lookup embeddings + managed_collision_modules: Dict of managed collision modules + return_remapped_features (bool): whether to return remapped input features + in addition to embeddings + + """ + + def __init__( + self, + embedding_collection: EmbeddingCollection, + managed_collision_collection: ManagedCollisionCollection, + return_remapped_features: bool = False, + ) -> None: + super().__init__( + embedding_collection, managed_collision_collection, return_remapped_features + ) + + # For consistency with embedding bag collection + self._embedding_collection: EmbeddingCollection = embedding_collection + + +class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingCollection): + """ + ManagedCollisionEmbeddingBagCollection represents a EmbeddingBagCollection module and a set of managed collision modules. + The inputs into the MC-EBC will first be modified by the managed collision module before being passed into the embedding bag collection. + + For details of input and output types, see EmbeddingBagCollection + + Args: + embedding_module: EmbeddingBagCollection to lookup embeddings + managed_collision_modules: Dict of managed collision modules + return_remapped_features (bool): whether to return remapped input features + in addition to embeddings + + """ + + def __init__( + self, + embedding_bag_collection: EmbeddingBagCollection, + managed_collision_collection: ManagedCollisionCollection, + return_remapped_features: bool = False, + ) -> None: + super().__init__( + embedding_bag_collection, + managed_collision_collection, + return_remapped_features, + ) + + # For backwards compat, as references existed in tests + self._embedding_bag_collection: EmbeddingBagCollection = ( + embedding_bag_collection + ) diff --git a/torchrec/modules/tests/test_mc_embedding_modules.py b/torchrec/modules/tests/test_mc_embedding_modules.py index a566b7649..ad372bac3 100644 --- a/torchrec/modules/tests/test_mc_embedding_modules.py +++ b/torchrec/modules/tests/test_mc_embedding_modules.py @@ -6,19 +6,26 @@ # LICENSE file in the root directory of this source tree. 
import unittest -from typing import cast, List, Optional +from copy import deepcopy +from typing import cast, Dict, List, Optional import torch -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection +from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) from torchrec.modules.mc_modules import ( DistanceLFU_EvictionPolicy, ManagedCollisionCollection, ManagedCollisionModule, MCHManagedCollisionModule, ) -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor class Tracer(torch.fx.Tracer): @@ -38,13 +45,13 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool class MCHManagedCollisionEmbeddingBagCollectionTest(unittest.TestCase): - def test_zch_ebc_train(self) -> None: + def test_zch_ebc_ec_train(self) -> None: device = torch.device("cpu") zch_size = 20 update_interval = 2 update_size = 10 - embedding_configs = [ + embedding_bag_configs = [ EmbeddingBagConfig( name="t1", embedding_dim=8, @@ -52,10 +59,24 @@ def test_zch_ebc_train(self) -> None: feature_names=["f1", "f2"], ), ] + embedding_configs = [ + EmbeddingConfig( + name="t1", + embedding_dim=8, + num_embeddings=zch_size, + feature_names=["f1", "f2"], + ), + ] ebc = EmbeddingBagCollection( + tables=embedding_bag_configs, + device=device, + ) + + ec = EmbeddingCollection( tables=embedding_configs, device=device, ) + mc_modules = { "t1": cast( ManagedCollisionModule, @@ -67,16 +88,29 @@ def test_zch_ebc_train(self) -> None: ), ), } - mcc = ManagedCollisionCollection( + mcc_ebc = ManagedCollisionCollection( managed_collision_modules=mc_modules, # pyre-ignore[6] + embedding_configs=embedding_bag_configs, + ) + + mcc_ec = ManagedCollisionCollection( + managed_collision_modules=deepcopy(mc_modules), + # pyre-ignore[6] embedding_configs=embedding_configs, ) mc_ebc = ManagedCollisionEmbeddingBagCollection( ebc, - mcc, + mcc_ebc, return_remapped_features=True, ) + mc_ec = ManagedCollisionEmbeddingCollection( + ec, + mcc_ec, + return_remapped_features=True, + ) + + mc_modules = [mc_ebc, mc_ec] update_one = KeyedJaggedTensor.from_lengths_sync( keys=["f1", "f2"], @@ -93,79 +127,109 @@ def test_zch_ebc_train(self) -> None: lengths=torch.ones((2 * update_size,), dtype=torch.int64), weights=None, ) - _, remapped_kjt1 = mc_ebc.forward(update_one) - _, remapped_kjt2 = mc_ebc.forward(update_one) - assert torch.all( - # pyre-ignore[16] - remapped_kjt1["f1"].values() - == zch_size - 1 - ), "all remapped ids should be mapped to end of range" - assert torch.all( - remapped_kjt1["f2"].values() == zch_size - 1 - ), "all remapped ids should be mapped to end of range" - - assert torch.all( - remapped_kjt2["f1"].values() == torch.arange(0, 10, dtype=torch.int64) - ) - assert torch.all( - remapped_kjt2["f2"].values() - == torch.cat( - [ - torch.arange(10, 19, dtype=torch.int64), - torch.tensor([zch_size - 1], dtype=torch.int64), # empty value - ] + + for mc_module in mc_modules: + out1, remapped_kjt1 = mc_module.forward(update_one) + out2, remapped_kjt2 = mc_module.forward(update_one) + + assert torch.all( + # pyre-ignore[16] 
+ remapped_kjt1["f1"].values() + == zch_size - 1 + ), "all remapped ids should be mapped to end of range" + assert torch.all( + remapped_kjt1["f2"].values() == zch_size - 1 + ), "all remapped ids should be mapped to end of range" + + assert torch.all( + remapped_kjt2["f1"].values() == torch.arange(0, 10, dtype=torch.int64) ) - ) - update_two = KeyedJaggedTensor.from_lengths_sync( - keys=["f1", "f2"], - values=torch.concat( - [ - torch.arange(2000, 2000 + update_size, dtype=torch.int64), - torch.arange( - 1000 + update_size, - 1000 + 2 * update_size, - dtype=torch.int64, - ), - ] - ), - lengths=torch.ones((2 * update_size,), dtype=torch.int64), - weights=None, - ) - _, remapped_kjt3 = mc_ebc.forward(update_two) - _, remapped_kjt4 = mc_ebc.forward(update_two) + assert torch.all( + remapped_kjt2["f2"].values() + == torch.cat( + [ + torch.arange(10, 19, dtype=torch.int64), + torch.tensor([zch_size - 1], dtype=torch.int64), # empty value + ] + ) + ) + + if isinstance(mc_module, ManagedCollisionEmbeddingCollection): + self.assertTrue(isinstance(out1, Dict)) + self.assertTrue(isinstance(out2, Dict)) + self.assertEqual(out1["f1"].values().size(), (update_size, 8)) + self.assertEqual(out2["f2"].values().size(), (update_size, 8)) + else: + self.assertTrue(isinstance(out1, KeyedTensor)) + self.assertTrue(isinstance(out2, KeyedTensor)) + self.assertEqual(out1["f1"].size(), (update_size, 8)) + self.assertEqual(out2["f2"].size(), (update_size, 8)) + + update_two = KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.concat( + [ + torch.arange(2000, 2000 + update_size, dtype=torch.int64), + torch.arange( + 1000 + update_size, + 1000 + 2 * update_size, + dtype=torch.int64, + ), + ] + ), + lengths=torch.ones((2 * update_size,), dtype=torch.int64), + weights=None, + ) + out3, remapped_kjt3 = mc_module.forward(update_two) + out4, remapped_kjt4 = mc_module.forward(update_two) - assert torch.all( - remapped_kjt3["f1"].values() == zch_size - 1 - ), "all remapped ids should be mapped to end of range" + assert torch.all( + remapped_kjt3["f1"].values() == zch_size - 1 + ), "all remapped ids should be mapped to end of range" - assert torch.all(remapped_kjt3["f2"].values() == remapped_kjt2["f2"].values()) + assert torch.all( + remapped_kjt3["f2"].values() == remapped_kjt2["f2"].values() + ) - assert torch.all( - remapped_kjt4["f1"].values() - == torch.cat( - [ - torch.arange(1, 10, dtype=torch.int64), - torch.tensor([zch_size - 1], dtype=torch.int64), # empty value - ] + assert torch.all( + remapped_kjt4["f1"].values() + == torch.cat( + [ + torch.arange(1, 10, dtype=torch.int64), + torch.tensor([zch_size - 1], dtype=torch.int64), # empty value + ] + ) ) - ) - assert torch.all( - remapped_kjt4["f2"].values() - == torch.cat( - [ - torch.arange(10, 19, dtype=torch.int64), - torch.tensor([0], dtype=torch.int64), # assigned first open slot - ] + assert torch.all( + remapped_kjt4["f2"].values() + == torch.cat( + [ + torch.arange(10, 19, dtype=torch.int64), + torch.tensor( + [0], dtype=torch.int64 + ), # assigned first open slot + ] + ) ) - ) - def test_zch_ebc_eval(self) -> None: + if isinstance(mc_module, ManagedCollisionEmbeddingCollection): + self.assertTrue(isinstance(out3, Dict)) + self.assertTrue(isinstance(out4, Dict)) + self.assertEqual(out3["f1"].values().size(), (update_size, 8)) + self.assertEqual(out4["f2"].values().size(), (update_size, 8)) + else: + self.assertTrue(isinstance(out3, KeyedTensor)) + self.assertTrue(isinstance(out4, KeyedTensor)) + self.assertEqual(out3["f1"].size(), 
(update_size, 8)) + self.assertEqual(out4["f2"].size(), (update_size, 8)) + + def test_zch_ebc_ec_eval(self) -> None: device = torch.device("cpu") zch_size = 20 update_interval = 2 update_size = 10 - embedding_configs = [ + embedding_bag_configs = [ EmbeddingBagConfig( name="t1", embedding_dim=8, @@ -173,7 +237,19 @@ def test_zch_ebc_eval(self) -> None: feature_names=["f1", "f2"], ), ] + embedding_configs = [ + EmbeddingConfig( + name="t1", + embedding_dim=8, + num_embeddings=zch_size, + feature_names=["f1", "f2"], + ), + ] ebc = EmbeddingBagCollection( + tables=embedding_bag_configs, + device=device, + ) + ec = EmbeddingCollection( tables=embedding_configs, device=device, ) @@ -188,89 +264,107 @@ def test_zch_ebc_eval(self) -> None: ), ), } - mcc = ManagedCollisionCollection( + mcc_ebc = ManagedCollisionCollection( managed_collision_modules=mc_modules, # pyre-ignore[6] + embedding_configs=embedding_bag_configs, + ) + + mcc_ec = ManagedCollisionCollection( + managed_collision_modules=deepcopy(mc_modules), + # pyre-ignore[6] embedding_configs=embedding_configs, ) mc_ebc = ManagedCollisionEmbeddingBagCollection( ebc, - mcc, + mcc_ebc, return_remapped_features=True, ) - - update_one = KeyedJaggedTensor.from_lengths_sync( - keys=["f1", "f2"], - values=torch.concat( - [ - torch.arange(1000, 1000 + update_size, dtype=torch.int64), - torch.arange( - 1000 + update_size, - 1000 + 2 * update_size, - dtype=torch.int64, - ), - ] - ), - lengths=torch.ones((2 * update_size,), dtype=torch.int64), - weights=None, - ) - _, remapped_kjt1 = mc_ebc.forward(update_one) - _, remapped_kjt2 = mc_ebc.forward(update_one) - - assert torch.all( - # pyre-ignore[16] - remapped_kjt1["f1"].values() - == zch_size - 1 - ), "all remapped ids should be mapped to end of range" - assert torch.all( - remapped_kjt1["f2"].values() == zch_size - 1 - ), "all remapped ids should be mapped to end of range" - - assert torch.all( - remapped_kjt2["f1"].values() == torch.arange(0, 10, dtype=torch.int64) - ) - assert torch.all( - remapped_kjt2["f2"].values() - == torch.cat( - [ - torch.arange(10, 19, dtype=torch.int64), - torch.tensor([zch_size - 1], dtype=torch.int64), # empty value - ] - ) + mc_ec = ManagedCollisionEmbeddingCollection( + ec, + mcc_ec, + return_remapped_features=True, ) - # Trigger eval mode, zch should not update - mc_ebc.eval() + mc_modules = [mc_ebc, mc_ec] + + for mc_module in mc_modules: + update_one = KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.concat( + [ + torch.arange(1000, 1000 + update_size, dtype=torch.int64), + torch.arange( + 1000 + update_size, + 1000 + 2 * update_size, + dtype=torch.int64, + ), + ] + ), + lengths=torch.ones((2 * update_size,), dtype=torch.int64), + weights=None, + ) + _, remapped_kjt1 = mc_module.forward(update_one) + _, remapped_kjt2 = mc_module.forward(update_one) + + assert torch.all( + # pyre-ignore[16] + remapped_kjt1["f1"].values() + == zch_size - 1 + ), "all remapped ids should be mapped to end of range" + assert torch.all( + remapped_kjt1["f2"].values() == zch_size - 1 + ), "all remapped ids should be mapped to end of range" + + assert torch.all( + remapped_kjt2["f1"].values() == torch.arange(0, 10, dtype=torch.int64) + ) + assert torch.all( + remapped_kjt2["f2"].values() + == torch.cat( + [ + torch.arange(10, 19, dtype=torch.int64), + torch.tensor([zch_size - 1], dtype=torch.int64), # empty value + ] + ) + ) - update_two = KeyedJaggedTensor.from_lengths_sync( - keys=["f1", "f2"], - values=torch.concat( - [ - torch.arange(2000, 2000 + update_size, 
dtype=torch.int64), - torch.arange( - 1000 + update_size, - 1000 + 2 * update_size, - dtype=torch.int64, - ), - ] - ), - lengths=torch.ones((2 * update_size,), dtype=torch.int64), - weights=None, - ) - _, remapped_kjt3 = mc_ebc.forward(update_two) - _, remapped_kjt4 = mc_ebc.forward(update_two) + # Trigger eval mode, zch should not update + mc_module.eval() + + update_two = KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.concat( + [ + torch.arange(2000, 2000 + update_size, dtype=torch.int64), + torch.arange( + 1000 + update_size, + 1000 + 2 * update_size, + dtype=torch.int64, + ), + ] + ), + lengths=torch.ones((2 * update_size,), dtype=torch.int64), + weights=None, + ) + _, remapped_kjt3 = mc_module.forward(update_two) + _, remapped_kjt4 = mc_module.forward(update_two) - assert torch.all( - remapped_kjt3["f1"].values() == zch_size - 1 - ), "all remapped ids should be mapped to end of range" + assert torch.all( + remapped_kjt3["f1"].values() == zch_size - 1 + ), "all remapped ids should be mapped to end of range" - assert torch.all(remapped_kjt3["f2"].values() == remapped_kjt2["f2"].values()) + assert torch.all( + remapped_kjt3["f2"].values() == remapped_kjt2["f2"].values() + ) - assert torch.all( - remapped_kjt4["f1"].values() == zch_size - 1 - ), "all remapped ids should be mapped to end of range" + assert torch.all( + remapped_kjt4["f1"].values() == zch_size - 1 + ), "all remapped ids should be mapped to end of range" - assert torch.all(remapped_kjt4["f2"].values() == remapped_kjt2["f2"].values()) + assert torch.all( + remapped_kjt4["f2"].values() == remapped_kjt2["f2"].values() + ) def test_mc_collection_traceable(self) -> None: device = torch.device("cpu") @@ -315,14 +409,14 @@ def test_mc_collection_traceable(self) -> None: # TODO: since this is unsharded module, also check torch.jit.script - def test_mch_ebc(self) -> None: + def test_mch_ebc_ec(self) -> None: device = torch.device("cpu") zch_size = 10 mch_size = 10 update_interval = 2 update_size = 10 - embedding_configs = [ + embedding_bag_configs = [ EmbeddingBagConfig( name="t1", embedding_dim=8, @@ -330,7 +424,20 @@ def test_mch_ebc(self) -> None: feature_names=["f1", "f2"], ), ] + embedding_configs = [ + EmbeddingConfig( + name="t1", + embedding_dim=8, + num_embeddings=zch_size + mch_size, + feature_names=["f1", "f2"], + ), + ] + ebc = EmbeddingBagCollection( + tables=embedding_bag_configs, + device=device, + ) + ec = EmbeddingCollection( tables=embedding_configs, device=device, ) @@ -351,100 +458,114 @@ def preprocess_func(id: torch.Tensor, hash_size: int) -> torch.Tensor: ), ), } - mcc = ManagedCollisionCollection( + mcc_ec = ManagedCollisionCollection( managed_collision_modules=mc_modules, # pyre-ignore[6] embedding_configs=embedding_configs, ) + mcc_ebc = ManagedCollisionCollection( + managed_collision_modules=deepcopy(mc_modules), + # pyre-ignore[6] + embedding_configs=embedding_bag_configs, + ) mc_ebc = ManagedCollisionEmbeddingBagCollection( ebc, - mcc, + mcc_ebc, return_remapped_features=True, ) - - update_one = KeyedJaggedTensor.from_lengths_sync( - keys=["f1", "f2"], - values=torch.concat( - [ - torch.arange(1000, 1000 + update_size, dtype=torch.int64), - torch.arange( - 1000 + update_size, - 1000 + 2 * update_size, - dtype=torch.int64, - ), - ] - ), - lengths=torch.ones((2 * update_size,), dtype=torch.int64), - weights=None, + mc_ec = ManagedCollisionEmbeddingCollection( + ec, + mcc_ec, + return_remapped_features=True, ) - _, remapped_kjt1 = mc_ebc.forward(update_one) - _, remapped_kjt2 = 
mc_ebc.forward(update_one) - - assert torch.all( - # pyre-ignore[16] - remapped_kjt1["f1"].values() - == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64) - ), "all remapped ids are in mch section" - assert torch.all( - remapped_kjt1["f2"].values() - == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64) - ), "all remapped ids are in mch section" - - assert torch.all( - remapped_kjt2["f1"].values() - == torch.cat( - [ - torch.arange(0, 9, dtype=torch.int64), - torch.tensor([19], dtype=torch.int64), # % MCH for last value - ] + mc_modules = [mc_ebc, mc_ec] + + for mc_module in mc_modules: + update_one = KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.concat( + [ + torch.arange(1000, 1000 + update_size, dtype=torch.int64), + torch.arange( + 1000 + update_size, + 1000 + 2 * update_size, + dtype=torch.int64, + ), + ] + ), + lengths=torch.ones((2 * update_size,), dtype=torch.int64), + weights=None, + ) + _, remapped_kjt1 = mc_module.forward(update_one) + _, remapped_kjt2 = mc_module.forward(update_one) + + assert torch.all( + # pyre-ignore[16] + remapped_kjt1["f1"].values() + == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64) + ), "all remapped ids are in mch section" + assert torch.all( + remapped_kjt1["f2"].values() + == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64) + ), "all remapped ids are in mch section" + + assert torch.all( + remapped_kjt2["f1"].values() + == torch.cat( + [ + torch.arange(0, 9, dtype=torch.int64), + torch.tensor([19], dtype=torch.int64), # % MCH for last value + ] + ) ) - ) - - assert torch.all( - remapped_kjt2["f2"].values() - == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64) - ), "all remapped ids are in mch section" - - update_two = KeyedJaggedTensor.from_lengths_sync( - keys=["f1", "f2"], - values=torch.concat( - [ - torch.arange(2000, 2000 + update_size, dtype=torch.int64), - torch.arange( - 1000 + update_size, - 1000 + 2 * update_size, - dtype=torch.int64, - ), - ] - ), - lengths=torch.ones((2 * update_size,), dtype=torch.int64), - weights=None, - ) - - _, remapped_kjt3 = mc_ebc.forward(update_two) - _, remapped_kjt4 = mc_ebc.forward(update_two) - - assert torch.all( - remapped_kjt3["f1"].values() - == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64) - ), "all remapped ids are in mch section" - - assert torch.all( - remapped_kjt3["f2"].values() - == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64) - ), "all remapped ids are in mch section" - assert torch.all( - remapped_kjt4["f1"].values() - == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64) - ), "all remapped ids are in mch section" + assert torch.all( + remapped_kjt2["f2"].values() + == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64) + ), "all remapped ids are in mch section" + + update_two = KeyedJaggedTensor.from_lengths_sync( + keys=["f1", "f2"], + values=torch.concat( + [ + torch.arange(2000, 2000 + update_size, dtype=torch.int64), + torch.arange( + 1000 + update_size, + 1000 + 2 * update_size, + dtype=torch.int64, + ), + ] + ), + lengths=torch.ones((2 * update_size,), dtype=torch.int64), + weights=None, + ) - assert torch.all( - remapped_kjt4["f2"].values() - == torch.cat( - [ - torch.arange(0, 9, dtype=torch.int64), - torch.tensor([19], dtype=torch.int64), # assigned first open slot - ] + _, remapped_kjt3 = mc_module.forward(update_two) + _, remapped_kjt4 = mc_module.forward(update_two) + + assert torch.all( + remapped_kjt3["f1"].values() + == 
torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64) + ), "all remapped ids are in mch section" + + assert torch.all( + remapped_kjt3["f2"].values() + == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64) + ), "all remapped ids are in mch section" + + assert torch.all( + remapped_kjt4["f1"].values() + == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64) + ), "all remapped ids are in mch section" + + assert torch.all( + remapped_kjt4["f2"].values() + == torch.cat( + [ + torch.arange(0, 9, dtype=torch.int64), + torch.tensor( + [19], dtype=torch.int64 + ), # assigned first open slot + ] + ) ) - )
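
Example usage (illustrative sketch, not part of the patch): the snippet below shows how the new ManagedCollisionEmbeddingCollection can wrap an EmbeddingCollection, mirroring the SparseArch construction in test_mc_embedding.py above. The table config, input_hash_size, and eviction settings are assumptions lifted from the tests, not values this change prescribes.

import torch
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)

device = torch.device("cpu")
tables = [
    EmbeddingConfig(
        name="table_0",
        feature_names=["feature_0"],
        embedding_dim=8,
        num_embeddings=16,
    ),
]

# One managed-collision module per table, keyed by table name.
# mch_size/mch_hash_func are left as None for a pure ZCH setup, as in the tests.
mc_modules = {
    "table_0": MCHManagedCollisionModule(
        zch_size=tables[0].num_embeddings,
        mch_size=None,
        mch_hash_func=None,
        input_hash_size=4000,
        device=device,
        eviction_interval=2,
        eviction_policy=DistanceLFU_EvictionPolicy(),
    ),
}

# Inputs are remapped by the managed-collision collection before the embedding lookup;
# with return_remapped_features=True the remapped KJT is returned alongside the embeddings.
mc_ec = ManagedCollisionEmbeddingCollection(
    EmbeddingCollection(tables=tables, device=device),
    ManagedCollisionCollection(
        managed_collision_modules=mc_modules,
        embedding_configs=tables,
    ),
    return_remapped_features=True,
)

# embeddings, remapped_kjt = mc_ec(kjt)
# In a distributed setup, the module is sharded with the new
# ManagedCollisionEmbeddingCollectionSharder (now included in get_default_sharders()),
# e.g. via construct_module_sharding_plan(..., sharder=...) and _shard_modules(...),
# as exercised in test_mc_embedding.py.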