diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py
index c980a9de5..39f1cbc68 100644
--- a/torchrec/distributed/mc_modules.py
+++ b/torchrec/distributed/mc_modules.py
@@ -51,10 +51,7 @@
     ShardMetadata,
 )
 from torchrec.distributed.utils import append_prefix
-from torchrec.modules.mc_modules import (
-    apply_mc_method_to_jt_dict,
-    ManagedCollisionCollection,
-)
+from torchrec.modules.mc_modules import ManagedCollisionCollection
 from torchrec.modules.utils import construct_jagged_tensors
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

@@ -191,16 +188,20 @@ def _initialize_torch_state(self) -> None:
             if name not in shardable_buffers:
                 continue

+            sharded_sizes = list(tensor.shape)
+            sharded_sizes[0] = shard_size
+            shard_offsets = [0] * len(sharded_sizes)
+            shard_offsets[0] = shard_offset
+            global_sizes = list(tensor.shape)
+            global_sizes[0] = global_size
             self._model_parallel_mc_buffer_name_to_sharded_tensor[name] = (
                 ShardedTensor._init_from_local_shards(
                     [
                         Shard(
                             tensor=tensor,
                             metadata=ShardMetadata(
-                                # pyre-ignore [6]
-                                shard_offsets=[shard_offset],
-                                # pyre-ignore [6]
-                                shard_sizes=[shard_size],
+                                shard_offsets=shard_offsets,
+                                shard_sizes=sharded_sizes,
                                 placement=(
                                     f"rank:{self._env.rank}/cuda:"
                                     f"{get_local_rank(self._env.world_size, self._env.rank)}"
@@ -208,8 +209,7 @@ def _initialize_torch_state(self) -> None:
                             ),
                         )
                     ],
-                    # pyre-ignore [6]
-                    torch.Size([global_size]),
+                    torch.Size(global_sizes),
                     process_group=self._env.process_group,
                 )
             )
@@ -256,9 +256,7 @@ def _create_managed_collision_modules(
         self,
         module: ManagedCollisionCollection
     ) -> None:
-        self._mc_module_name_shard_metadata: DefaultDict[
-            str, DefaultDict[str, List[int]]
-        ] = defaultdict(lambda: defaultdict(list))
+        self._mc_module_name_shard_metadata: DefaultDict[str, List[int]] = defaultdict()
         self._feature_to_offset: Dict[str, int] = {}

         for sharding in self._embedding_shardings:
@@ -392,15 +390,19 @@ def input_dist(
             self._has_uninitialized_input_dists = False

         with torch.no_grad():
+            features_dict = features.to_dict()
+            output: Dict[str, JaggedTensor] = features_dict.copy()
+            for table, mc_module in self._managed_collision_modules.items():
+                feature_list: List[str] = self._table_to_features[table]
+                mc_input: Dict[str, JaggedTensor] = {}
+                for feature in feature_list:
+                    mc_input[feature] = features_dict[feature]
+                mc_input = mc_module.preprocess(mc_input)
+                output.update(mc_input)
+
             # NOTE shared features not currently supported
-            features = KeyedJaggedTensor.from_jt_dict(
-                apply_mc_method_to_jt_dict(
-                    "preprocess",
-                    features.to_dict(),
-                    self._table_to_features,
-                    self._managed_collision_modules,
-                )
-            )
+            features = KeyedJaggedTensor.from_jt_dict(output)
+
             if self._features_order:
                 features = features.permute(
                     self._features_order,
@@ -456,19 +458,17 @@ def compute(
                 -1, features.stride()
             )
             features_dict = features.to_dict()
-            features_dict = apply_mc_method_to_jt_dict(
-                "profile",
-                features_dict=features_dict,
-                table_to_features=self._table_to_features,
-                managed_collisions=self._managed_collision_modules,
-            )
-            features_dict = apply_mc_method_to_jt_dict(
-                "remap",
-                features_dict=features_dict,
-                table_to_features=self._table_to_features,
-                managed_collisions=self._managed_collision_modules,
-            )
-            remapped_kjts.append(KeyedJaggedTensor.from_jt_dict(features_dict))
+            output: Dict[str, JaggedTensor] = features_dict.copy()
+            for table, mc_module in self._managed_collision_modules.items():
+                feature_list: List[str] = self._table_to_features[table]
+                mc_input: Dict[str, JaggedTensor] = {}
+                for feature in feature_list:
+                    mc_input[feature] = features_dict[feature]
+                mc_input = mc_module.profile(mc_input)
+                mc_input = mc_module.remap(mc_input)
+                output.update(mc_input)
+
+            remapped_kjts.append(KeyedJaggedTensor.from_jt_dict(output))

         return KJTList(remapped_kjts)

@@ -522,6 +522,7 @@ def create_context(self) -> ManagedCollisionCollectionContext:
         return ManagedCollisionCollectionContext(sharding_contexts=[])

     def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
+        # TODO (bwen): this does not include `_hash_zch_identities`
         for fqn, _ in self.named_buffers():
             yield append_prefix(prefix, fqn)

diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py
index 554a55d76..46360a02e 100644
--- a/torchrec/modules/mc_modules.py
+++ b/torchrec/modules/mc_modules.py
@@ -10,7 +10,6 @@
 #!/usr/bin/env python3

 import abc
-from collections import defaultdict
 from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

 import torch
@@ -30,23 +29,15 @@

 @torch.fx.wrap
 def apply_mc_method_to_jt_dict(
+    mc_module: nn.Module,
     method: str,
     features_dict: Dict[str, JaggedTensor],
-    table_to_features: Dict[str, List[str]],
-    managed_collisions: nn.ModuleDict,
 ) -> Dict[str, JaggedTensor]:
     """
     Applies an MC method to a dictionary of JaggedTensors, returning the updated dictionary with same ordering
     """
-    mc_output: Dict[str, JaggedTensor] = features_dict.copy()
-    for table, features in table_to_features.items():
-        mc_input: Dict[str, JaggedTensor] = {}
-        for feature in features:
-            mc_input[feature] = features_dict[feature]
-        mc_module = managed_collisions[table]
-        attr = getattr(mc_module, method)
-        mc_output.update(attr(mc_input))
-    return mc_output
+    attr = getattr(mc_module, method)
+    return attr(features_dict)


 @torch.no_grad()
@@ -153,6 +144,14 @@ def evict(self) -> Optional[torch.Tensor]:
         """
         pass

+    @abc.abstractmethod
+    def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
+        pass
+
+    @abc.abstractmethod
+    def profile(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
+        pass
+
     @abc.abstractmethod
     def forward(
         self,
@@ -203,6 +202,8 @@ class ManagedCollisionCollection(nn.Module):
         embedding_confgs (List[BaseEmbeddingConfig]): List of embedding configs, for each table with a managed collsion module
     """

+    _table_to_features: Dict[str, List[str]]
+
     def __init__(
         self,
         managed_collision_modules: Dict[str, ManagedCollisionModule],
@@ -216,10 +217,13 @@ def __init__(
             for config in embedding_configs
             for feature in config.feature_names
         }
-        self._table_to_features: Dict[str, List[str]] = defaultdict(list)
+        self._table_to_features = {}
         self._compute_jt_dict_to_kjt = ComputeJTDictToKJT()

         for feature, table in self._feature_to_table.items():
+            if table not in self._table_to_features:
+                self._table_to_features[table] = []
+
             self._table_to_features[table].append(feature)

         table_to_config = {config.name: config for config in embedding_configs}
@@ -243,25 +247,18 @@ def forward(
         self,
         features: KeyedJaggedTensor,
     ) -> KeyedJaggedTensor:
-        features_dict = apply_mc_method_to_jt_dict(
-            "preprocess",
-            features_dict=features.to_dict(),
-            table_to_features=self._table_to_features,
-            managed_collisions=self._managed_collision_modules,
-        )
-        features_dict = apply_mc_method_to_jt_dict(
-            "profile",
-            features_dict=features_dict,
-            table_to_features=self._table_to_features,
-            managed_collisions=self._managed_collision_modules,
-        )
-        features_dict = apply_mc_method_to_jt_dict(
-            "remap",
-            features_dict=features_dict,
-            table_to_features=self._table_to_features,
-            managed_collisions=self._managed_collision_modules,
-        )
-        return self._compute_jt_dict_to_kjt(features_dict)
+        features_dict = features.to_dict()
+        output: Dict[str, JaggedTensor] = features_dict.copy()
+        for table, mc_module in self._managed_collision_modules.items():
+            feature_list: List[str] = self._table_to_features[table]
+            mc_input: Dict[str, JaggedTensor] = {}
+            for feature in feature_list:
+                mc_input[feature] = features_dict[feature]
+            mc_input = mc_module.preprocess(mc_input)
+            mc_input = mc_module.profile(mc_input)
+            mc_input = mc_module.remap(mc_input)
+            output.update(mc_input)
+        return self._compute_jt_dict_to_kjt(output)

     def evict(self) -> Dict[str, Optional[torch.Tensor]]:
         evictions: Dict[str, Optional[torch.Tensor]] = {}
@@ -933,7 +930,17 @@ def _init_history_buffers(self, features: Dict[str, JaggedTensor]) -> None:
             self._history_metadata[metadata_name] = getattr(self, buffer_name)

     @torch.no_grad()
-    def preprocess(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
+    def preprocess(
+        self,
+        features: Dict[str, JaggedTensor],
+    ) -> Dict[str, JaggedTensor]:
+        return apply_mc_method_to_jt_dict(
+            self,
+            "_preprocess",
+            features,
+        )
+
+    def _preprocess(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
         if self._input_hash_func is None:
             return features
         preprocessed_features: Dict[str, JaggedTensor] = {}
@@ -1070,6 +1077,16 @@ def _coalesce_history(self) -> None:
     def profile(
         self,
         features: Dict[str, JaggedTensor],
+    ) -> Dict[str, JaggedTensor]:
+        return apply_mc_method_to_jt_dict(
+            self,
+            "_profile",
+            features,
+        )
+
+    def _profile(
+        self,
+        features: Dict[str, JaggedTensor],
     ) -> Dict[str, JaggedTensor]:
         if not self.training:
             return features
@@ -1115,7 +1132,17 @@ def profile(
         return features

     @torch.no_grad()
-    def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
+    def remap(
+        self,
+        features: Dict[str, JaggedTensor],
+    ) -> Dict[str, JaggedTensor]:
+        return apply_mc_method_to_jt_dict(
+            self,
+            "_remap",
+            features,
+        )
+
+    def _remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
         remapped_features: Dict[str, JaggedTensor] = {}
         for name, feature in features.items():
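The `_initialize_torch_state` hunk generalizes the shard metadata from 1-D buffers to buffers of arbitrary rank that are sharded along dim 0: only the first entry of the size/offset lists changes, the remaining dims are copied from the local buffer. A small standalone sketch of that arithmetic, using made-up sizes rather than anything from the PR:

import torch

# Hypothetical local shard and sharding parameters, for illustration only.
tensor = torch.zeros(25, 4)                      # local shard held by this rank
shard_offset, shard_size, global_size = 50, 25, 100

sharded_sizes = list(tensor.shape)               # start from the local shape
sharded_sizes[0] = shard_size                    # [25, 4]
shard_offsets = [0] * len(sharded_sizes)         # unsharded dims start at 0
shard_offsets[0] = shard_offset                  # [50, 0]
global_sizes = list(tensor.shape)
global_sizes[0] = global_size                    # [100, 4]

assert sharded_sizes == [25, 4]
assert shard_offsets == [50, 0]
assert global_sizes == [100, 4]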
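`ManagedCollisionCollection.forward` and the sharded module's `input_dist`/`compute` now share the same fan-out pattern: slice the flattened feature dict per table, run that table's managed-collision hooks on its own slice, and merge the results back over a copy of the input. A toy, torchrec-free sketch of that pattern; `ToyMCModule`, `fan_out`, and the modulo remap are invented stand-ins for illustration, not torchrec APIs:

from typing import Dict, List


class ToyMCModule:
    """Stands in for a ManagedCollisionModule: each hook takes and returns a
    feature-name -> values dict covering only the features of its table."""

    def __init__(self, zch_size: int) -> None:
        self._zch_size = zch_size

    def preprocess(self, feats: Dict[str, List[int]]) -> Dict[str, List[int]]:
        return feats  # no-op in this toy example

    def profile(self, feats: Dict[str, List[int]]) -> Dict[str, List[int]]:
        return feats  # real modules update frequency/history state here

    def remap(self, feats: Dict[str, List[int]]) -> Dict[str, List[int]]:
        # collapse raw ids into the managed id range
        return {
            name: [v % self._zch_size for v in vals] for name, vals in feats.items()
        }


def fan_out(
    features_dict: Dict[str, List[int]],
    table_to_features: Dict[str, List[str]],
    mc_modules: Dict[str, ToyMCModule],
) -> Dict[str, List[int]]:
    # Mirrors the new forward(): build the per-table input slice, run the
    # module's hooks on that slice only, then merge the remapped slice back.
    output = dict(features_dict)
    for table, mc_module in mc_modules.items():
        mc_input = {f: features_dict[f] for f in table_to_features[table]}
        mc_input = mc_module.preprocess(mc_input)
        mc_input = mc_module.profile(mc_input)
        mc_input = mc_module.remap(mc_input)
        output.update(mc_input)
    return output


if __name__ == "__main__":
    features = {"f1": [101, 202], "f2": [303], "f3": [404, 505]}
    table_to_features = {"t1": ["f1", "f2"], "t2": ["f3"]}
    mc_modules = {"t1": ToyMCModule(zch_size=100), "t2": ToyMCModule(zch_size=50)}
    print(fan_out(features, table_to_features, mc_modules))
    # {'f1': [1, 2], 'f2': [3], 'f3': [4, 5]}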