From ddcfd64619afd7412633be73b193eb61824f56bc Mon Sep 17 00:00:00 2001 From: Dennis van der Staay Date: Fri, 26 Jul 2024 21:13:28 -0700 Subject: [PATCH] Open Slots API (#2249) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2249 Adds the concept of open slots to show users whether an update is just growing the table or actually replacing ids. Also fixes the default input_hash_size to the max int64 value (2**63 - 1). Example logs when only inserting: {F1774359971} vs. when replacing: {F1774361026} Reviewed By: iamzainhuda Differential Revision: D59931393 fbshipit-source-id: 3d46198f5e4d2bedbeaee80886b64f8e4b1817f1 --- torchrec/distributed/mc_embedding.py | 8 +++- torchrec/distributed/mc_embedding_modules.py | 5 ++- torchrec/distributed/mc_embeddingbag.py | 5 ++- torchrec/distributed/mc_modules.py | 10 ++++- torchrec/modules/mc_modules.py | 40 ++++++++++++++++++- .../tests/test_mc_embedding_modules.py | 16 ++++++++ torchrec/modules/tests/test_mc_modules.py | 6 ++- 7 files changed, 81 insertions(+), 9 deletions(-) diff --git a/torchrec/distributed/mc_embedding.py b/torchrec/distributed/mc_embedding.py index 397b05cc8..adf2e9de8 100644 --- a/torchrec/distributed/mc_embedding.py +++ b/torchrec/distributed/mc_embedding.py @@ -9,7 +9,7 @@ #!/usr/bin/env python3 -from typing import Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type import torch @@ -104,11 +104,15 @@ def __init__( self, ec_sharder: Optional[EmbeddingCollectionSharder] = None, mc_sharder: Optional[ManagedCollisionCollectionSharder] = None, + fused_params: Optional[Dict[str, Any]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: super().__init__( ec_sharder - or EmbeddingCollectionSharder(qcomm_codecs_registry=qcomm_codecs_registry), + or EmbeddingCollectionSharder( + qcomm_codecs_registry=qcomm_codecs_registry, + fused_params=fused_params, + ), mc_sharder or ManagedCollisionCollectionSharder(), qcomm_codecs_registry=qcomm_codecs_registry, ) diff --git 
a/torchrec/distributed/mc_embedding_modules.py b/torchrec/distributed/mc_embedding_modules.py index 563e70fcf..f24a37f11 100644 --- a/torchrec/distributed/mc_embedding_modules.py +++ b/torchrec/distributed/mc_embedding_modules.py @@ -152,6 +152,7 @@ def input_dist( ) def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None: + open_slots = None for table, evictions_indices_for_table in evictions_per_table.items(): if evictions_indices_for_table is not None: (tbe, logical_table_ids) = self._table_to_tbe_and_index[table] @@ -160,8 +161,10 @@ def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None dtype=torch.long, device=self._device, ) + if open_slots is None: + open_slots = self._managed_collision_collection.open_slots() logger.info( - f"Evicting {evictions_indices_for_table.numel()} ids from {table}" + f"Table {table}: inserting {evictions_indices_for_table.numel()} ids with {open_slots[table].item()} open slots" ) with torch.no_grad(): # embeddings, and optimizer state will be reset diff --git a/torchrec/distributed/mc_embeddingbag.py b/torchrec/distributed/mc_embeddingbag.py index 47b5d6670..7cfd81743 100644 --- a/torchrec/distributed/mc_embeddingbag.py +++ b/torchrec/distributed/mc_embeddingbag.py @@ -9,7 +9,7 @@ #!/usr/bin/env python3 from dataclasses import dataclass -from typing import Dict, Optional, Type +from typing import Any, Dict, Optional, Type import torch from torchrec.distributed.embedding_types import KJTList @@ -93,12 +93,13 @@ def __init__( self, ebc_sharder: Optional[EmbeddingBagCollectionSharder] = None, mc_sharder: Optional[ManagedCollisionCollectionSharder] = None, + fused_params: Optional[Dict[str, Any]] = None, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: super().__init__( ebc_sharder or EmbeddingBagCollectionSharder( - qcomm_codecs_registry=qcomm_codecs_registry + fused_params=fused_params, qcomm_codecs_registry=qcomm_codecs_registry ), mc_sharder or 
ManagedCollisionCollectionSharder(), qcomm_codecs_registry=qcomm_codecs_registry, diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index aa7a2c07e..c980a9de5 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -144,7 +144,6 @@ def __init__( qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: super().__init__() - self._device = device self._env = env self._table_name_to_parameter_sharding: Dict[str, ParameterSharding] = ( @@ -482,6 +481,15 @@ def evict(self) -> Dict[str, Optional[torch.Tensor]]: evictions[table] = managed_collision_module.evict() return evictions + def open_slots(self) -> Dict[str, torch.Tensor]: + open_slots: Dict[str, torch.Tensor] = {} + for ( + table, + managed_collision_module, + ) in self._managed_collision_modules.items(): + open_slots[table] = managed_collision_module.open_slots() + return open_slots + def output_dist( self, ctx: ManagedCollisionCollectionContext, diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py index c39df81c9..554a55d76 100644 --- a/torchrec/modules/mc_modules.py +++ b/torchrec/modules/mc_modules.py @@ -174,6 +174,13 @@ def input_size(self) -> int: """ pass + @abc.abstractmethod + def open_slots(self) -> torch.Tensor: + """ + Returns number of unused slots in managed collision module + """ + pass + @abc.abstractmethod def rebuild_with_output_id_range( self, @@ -265,6 +272,15 @@ def evict(self) -> Dict[str, Optional[torch.Tensor]]: evictions[table] = managed_collision_module.evict() return evictions + def open_slots(self) -> Dict[str, torch.Tensor]: + open_slots: Dict[str, torch.Tensor] = {} + for ( + table, + managed_collision_module, + ) in self._managed_collision_modules.items(): + open_slots[table] = managed_collision_module.open_slots() + return open_slots + class MCHEvictionPolicyMetadataInfo(NamedTuple): metadata_name: str @@ -516,6 +532,7 @@ def coalesce_history_metadata( 
additional_ids: Optional[torch.Tensor] = None, threshold_mask: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: + coalesced_history_metadata: Dict[str, torch.Tensor] = {} history_last_access_iter = history_metadata["last_access_iter"] if additional_ids is not None: @@ -792,7 +809,7 @@ def __init__( device: torch.device, eviction_policy: MCHEvictionPolicy, eviction_interval: int, - input_hash_size: int = 2**63, + input_hash_size: int = (2**63) - 1, input_hash_func: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = None, mch_size: Optional[int] = None, # experimental mch_hash_func: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = None, @@ -845,6 +862,22 @@ def _init_buffers(self) -> None: device=self.device, ), ) + self.register_buffer( + "_zch_slots", + torch.tensor( + [(self._zch_size - 1)], + dtype=torch.int64, + device=self.device, + ), + persistent=False, + ) + self.register_buffer( + "_delimiter", + torch.tensor( + [torch.iinfo(torch.int64).max], dtype=torch.int64, device=self.device + ), + persistent=False, + ) self.register_buffer( "_mch_remapped_ids_mapping", torch.arange(self._zch_size, dtype=torch.int64, device=self.device), @@ -1140,6 +1173,11 @@ def output_size(self) -> int: def input_size(self) -> int: return self._input_hash_size + def open_slots(self) -> torch.Tensor: + return self._zch_slots - torch.searchsorted( + self._mch_sorted_raw_ids, self._delimiter + ) + @torch.no_grad() def evict(self) -> Optional[torch.Tensor]: if self._evicted: diff --git a/torchrec/modules/tests/test_mc_embedding_modules.py b/torchrec/modules/tests/test_mc_embedding_modules.py index 63f6714bb..d1babc9a2 100644 --- a/torchrec/modules/tests/test_mc_embedding_modules.py +++ b/torchrec/modules/tests/test_mc_embedding_modules.py @@ -131,9 +131,25 @@ def test_zch_ebc_ec_train(self) -> None: ) for mc_module in mc_modules: + + self.assertEqual( + mc_module._managed_collision_collection.open_slots()["t1"].item(), + zch_size - 1, + ) # (ZCH-1 slots) + out1, 
remapped_kjt1 = mc_module.forward(update_one) + + self.assertEqual( + mc_module._managed_collision_collection.open_slots()["t1"].item(), + zch_size - 1, + ) # prior update, ZCH-1 slots + out2, remapped_kjt2 = mc_module.forward(update_one) + self.assertEqual( + mc_module._managed_collision_collection.open_slots()["t1"].item(), 0 + ) # post update, 0 slots + assert torch.all( # pyre-ignore[16] remapped_kjt1["f1"].values() diff --git a/torchrec/modules/tests/test_mc_modules.py b/torchrec/modules/tests/test_mc_modules.py index 6d159b844..de2a76ec5 100644 --- a/torchrec/modules/tests/test_mc_modules.py +++ b/torchrec/modules/tests/test_mc_modules.py @@ -86,7 +86,7 @@ def test_lru_eviction(self) -> None: ) } mc_module.profile(features) - + self.assertEqual(mc_module.open_slots().item(), 1) ids = [3, 4, 5] features: Dict[str, JaggedTensor] = { "f1": JaggedTensor( @@ -95,7 +95,7 @@ def test_lru_eviction(self) -> None: ) } mc_module.profile(features) - + self.assertEqual(mc_module.open_slots().item(), 0) ids = [7, 8] features: Dict[str, JaggedTensor] = { "f1": JaggedTensor( @@ -104,6 +104,7 @@ def test_lru_eviction(self) -> None: ) } mc_module.profile(features) + self.assertEqual(mc_module.open_slots().item(), 0) _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids self.assertEqual( @@ -112,6 +113,7 @@ def test_lru_eviction(self) -> None: ) _mch_last_access_iter = mc_module._mch_last_access_iter self.assertEqual(list(_mch_last_access_iter), [2, 2, 3, 3, 3]) + self.assertEqual(mc_module.open_slots().item(), 0) def test_distance_lfu_eviction(self) -> None: mc_module = MCHManagedCollisionModule(