Commit: Open Slots API (#2249)
Summary:
Pull Request resolved: #2249

Adds the concept of open slots, so users can see whether an eviction pass is merely growing the table or actually replacing existing ids. Also fixes the default input_hash_size to the maximum int64 value (2**63 - 1); the previous default of 2**63 overflows a signed 64-bit integer.

Example logs when only inserting:

{F1774359971}

vs replacement:

{F1774361026}

Reviewed By: iamzainhuda

Differential Revision: D59931393

fbshipit-source-id: 3d46198f5e4d2bedbeaee80886b64f8e4b1817f1
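
For orientation, a minimal sketch of the new API, modeled on the unit tests in this diff. The constructor arguments and LRU_EvictionPolicy shown are assumptions drawn from existing torchrec modules, not part of this change:

import torch
from torchrec.modules.mc_modules import LRU_EvictionPolicy, MCHManagedCollisionModule
from torchrec.sparse.jagged_tensor import JaggedTensor

# zch_size=5 yields 4 usable slots; one slot is always held in reserve.
mc_module = MCHManagedCollisionModule(
    zch_size=5,
    device=torch.device("cpu"),
    eviction_policy=LRU_EvictionPolicy(),
    eviction_interval=1,
    input_hash_size=100,
)

features = {
    "f1": JaggedTensor(
        values=torch.tensor([1, 2, 3], dtype=torch.int64),
        lengths=torch.tensor([1, 1, 1], dtype=torch.int64),
    )
}
mc_module.profile(features)
print(mc_module.open_slots().item())  # expect 1: three ids fill three of four slots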
dstaay-fb authored and facebook-github-bot committed Jul 27, 2024
1 parent c89e9df commit ddcfd64
Showing 7 changed files with 81 additions and 9 deletions.
8 changes: 6 additions & 2 deletions torchrec/distributed/mc_embedding.py
@@ -9,7 +9,7 @@

#!/usr/bin/env python3

-from typing import Dict, List, Optional, Type
+from typing import Any, Dict, List, Optional, Type

import torch

@@ -104,11 +104,15 @@ def __init__(
        self,
        ec_sharder: Optional[EmbeddingCollectionSharder] = None,
        mc_sharder: Optional[ManagedCollisionCollectionSharder] = None,
+        fused_params: Optional[Dict[str, Any]] = None,
        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
    ) -> None:
        super().__init__(
            ec_sharder
-            or EmbeddingCollectionSharder(qcomm_codecs_registry=qcomm_codecs_registry),
+            or EmbeddingCollectionSharder(
+                qcomm_codecs_registry=qcomm_codecs_registry,
+                fused_params=fused_params,
+            ),
            mc_sharder or ManagedCollisionCollectionSharder(),
            qcomm_codecs_registry=qcomm_codecs_registry,
        )
5 changes: 4 additions & 1 deletion torchrec/distributed/mc_embedding_modules.py
@@ -152,6 +152,7 @@ def input_dist(
        )

    def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None:
+        open_slots = None
        for table, evictions_indices_for_table in evictions_per_table.items():
            if evictions_indices_for_table is not None:
                (tbe, logical_table_ids) = self._table_to_tbe_and_index[table]
@@ -160,8 +161,10 @@ def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None:
                    dtype=torch.long,
                    device=self._device,
                )
+                if open_slots is None:
+                    open_slots = self._managed_collision_collection.open_slots()
                logger.info(
-                    f"Evicting {evictions_indices_for_table.numel()} ids from {table}"
+                    f"Table {table}: inserting {evictions_indices_for_table.numel()} ids with {open_slots[table].item()} open slots"
                )
                with torch.no_grad():
                    # embeddings, and optimizer state will be reset
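The value of logging open slots beside the insert count: as long as open slots cover the batch, the table is only warming up; once they run out, each insert displaces a live id. A hypothetical reading of the new log line (numbers are illustrative, not from this diff):

# "Table t1: inserting 128 ids with 32 open slots"
inserting, open_slots = 128, 32
displaced = max(0, inserting - open_slots)  # 96 existing ids get replaced
table_still_growing = displaced == 0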
5 changes: 3 additions & 2 deletions torchrec/distributed/mc_embeddingbag.py
@@ -9,7 +9,7 @@
#!/usr/bin/env python3

from dataclasses import dataclass
-from typing import Dict, Optional, Type
+from typing import Any, Dict, Optional, Type

import torch
from torchrec.distributed.embedding_types import KJTList
@@ -93,12 +93,13 @@ def __init__(
        self,
        ebc_sharder: Optional[EmbeddingBagCollectionSharder] = None,
        mc_sharder: Optional[ManagedCollisionCollectionSharder] = None,
+        fused_params: Optional[Dict[str, Any]] = None,
        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
    ) -> None:
        super().__init__(
            ebc_sharder
            or EmbeddingBagCollectionSharder(
-                qcomm_codecs_registry=qcomm_codecs_registry
+                fused_params=fused_params, qcomm_codecs_registry=qcomm_codecs_registry
            ),
            mc_sharder or ManagedCollisionCollectionSharder(),
            qcomm_codecs_registry=qcomm_codecs_registry,
10 changes: 9 additions & 1 deletion torchrec/distributed/mc_modules.py
@@ -144,7 +144,6 @@ def __init__(
        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
    ) -> None:
        super().__init__()
-
        self._device = device
        self._env = env
        self._table_name_to_parameter_sharding: Dict[str, ParameterSharding] = (
@@ -482,6 +481,15 @@ def evict(self) -> Dict[str, Optional[torch.Tensor]]:
            evictions[table] = managed_collision_module.evict()
        return evictions

+    def open_slots(self) -> Dict[str, torch.Tensor]:
+        open_slots: Dict[str, torch.Tensor] = {}
+        for (
+            table,
+            managed_collision_module,
+        ) in self._managed_collision_modules.items():
+            open_slots[table] = managed_collision_module.open_slots()
+        return open_slots
+
    def output_dist(
        self,
        ctx: ManagedCollisionCollectionContext,
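The sharded collection mirrors the unsharded ManagedCollisionCollection (changed below), returning one tensor per table. A sketch of consuming the dict-valued API; the collection construction is assumed from existing torchrec module APIs and is not part of this diff:

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)

mcc = ManagedCollisionCollection(
    managed_collision_modules={
        "t1": MCHManagedCollisionModule(
            zch_size=100,
            device=torch.device("cpu"),
            eviction_policy=DistanceLFU_EvictionPolicy(),
            eviction_interval=2,
        )
    },
    embedding_configs=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
        )
    ],
)

for table, slots in mcc.open_slots().items():  # Dict[str, torch.Tensor]
    print(f"{table}: {int(slots.item())} open slots")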
40 changes: 39 additions & 1 deletion torchrec/modules/mc_modules.py
@@ -174,6 +174,13 @@ def input_size(self) -> int:
"""
pass

@abc.abstractmethod
def open_slots(self) -> torch.Tensor:
"""
Returns number of unused slots in managed collision module
"""
pass

@abc.abstractmethod
def rebuild_with_output_id_range(
self,
@@ -265,6 +272,15 @@ def evict(self) -> Dict[str, Optional[torch.Tensor]]:
            evictions[table] = managed_collision_module.evict()
        return evictions

+    def open_slots(self) -> Dict[str, torch.Tensor]:
+        open_slots: Dict[str, torch.Tensor] = {}
+        for (
+            table,
+            managed_collision_module,
+        ) in self._managed_collision_modules.items():
+            open_slots[table] = managed_collision_module.open_slots()
+        return open_slots
+

class MCHEvictionPolicyMetadataInfo(NamedTuple):
    metadata_name: str
@@ -516,6 +532,7 @@ def coalesce_history_metadata(
        additional_ids: Optional[torch.Tensor] = None,
        threshold_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
+
        coalesced_history_metadata: Dict[str, torch.Tensor] = {}
        history_last_access_iter = history_metadata["last_access_iter"]
        if additional_ids is not None:
@@ -792,7 +809,7 @@ def __init__(
        device: torch.device,
        eviction_policy: MCHEvictionPolicy,
        eviction_interval: int,
-        input_hash_size: int = 2**63,
+        input_hash_size: int = (2**63) - 1,
        input_hash_func: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = None,
        mch_size: Optional[int] = None,  # experimental
        mch_hash_func: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = None,
@@ -845,6 +862,22 @@ def _init_buffers(self) -> None:
                device=self.device,
            ),
        )
+        self.register_buffer(
+            "_zch_slots",
+            torch.tensor(
+                [(self._zch_size - 1)],
+                dtype=torch.int64,
+                device=self.device,
+            ),
+            persistent=False,
+        )
+        self.register_buffer(
+            "_delimiter",
+            torch.tensor(
+                [torch.iinfo(torch.int64).max], dtype=torch.int64, device=self.device
+            ),
+            persistent=False,
+        )
        self.register_buffer(
            "_mch_remapped_ids_mapping",
            torch.arange(self._zch_size, dtype=torch.int64, device=self.device),
@@ -1140,6 +1173,11 @@ def output_size(self) -> int:
    def input_size(self) -> int:
        return self._input_hash_size

+    def open_slots(self) -> torch.Tensor:
+        return self._zch_slots - torch.searchsorted(
+            self._mch_sorted_raw_ids, self._delimiter
+        )
+
    @torch.no_grad()
    def evict(self) -> Optional[torch.Tensor]:
        if self._evicted:
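Why searchsorted against a max-int64 delimiter counts used slots: _mch_sorted_raw_ids is kept sorted and is initialized to torch.iinfo(torch.int64).max, so occupied slots sort to the front and every unused slot holds the sentinel; the insertion point of the delimiter is therefore the number of occupied slots. A standalone numerical sketch (buffer contents illustrative):

import torch

zch_size = 5
sentinel = torch.iinfo(torch.int64).max  # matches the new _delimiter buffer

# Three real ids are in use; remaining slots still hold the sentinel.
mch_sorted_raw_ids = torch.tensor([10, 42, 77, sentinel, sentinel])
zch_slots = torch.tensor([zch_size - 1])  # capacity excludes the one reserved slot
delimiter = torch.tensor([sentinel])

used = torch.searchsorted(mch_sorted_raw_ids, delimiter)  # tensor([3])
print((zch_slots - used).item())  # 1 open slot

The new default input_hash_size of (2**63) - 1 is consistent with this: it is the largest value representable in int64, so valid raw ids stay distinguishable from sentinel-filled empty slots.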
16 changes: 16 additions & 0 deletions torchrec/modules/tests/test_mc_embedding_modules.py
@@ -131,9 +131,25 @@ def test_zch_ebc_ec_train(self) -> None:
            )

        for mc_module in mc_modules:
+
+            self.assertEqual(
+                mc_module._managed_collision_collection.open_slots()["t1"].item(),
+                zch_size - 1,
+            )  # (ZCH-1 slots)
+
            out1, remapped_kjt1 = mc_module.forward(update_one)
+
+            self.assertEqual(
+                mc_module._managed_collision_collection.open_slots()["t1"].item(),
+                zch_size - 1,
+            )  # prior update, ZCH-1 slots
+
            out2, remapped_kjt2 = mc_module.forward(update_one)
+
+            self.assertEqual(
+                mc_module._managed_collision_collection.open_slots()["t1"].item(), 0
+            )  # post update, 0 slots
+
            assert torch.all(
                # pyre-ignore[16]
                remapped_kjt1["f1"].values()
6 changes: 4 additions & 2 deletions torchrec/modules/tests/test_mc_modules.py
@@ -86,7 +86,7 @@ def test_lru_eviction(self) -> None:
            )
        }
        mc_module.profile(features)
-
+        self.assertEqual(mc_module.open_slots().item(), 1)
        ids = [3, 4, 5]
        features: Dict[str, JaggedTensor] = {
            "f1": JaggedTensor(
@@ -95,7 +95,7 @@
            )
        }
        mc_module.profile(features)
-
+        self.assertEqual(mc_module.open_slots().item(), 0)
        ids = [7, 8]
        features: Dict[str, JaggedTensor] = {
            "f1": JaggedTensor(
@@ -104,6 +104,7 @@
            )
        }
        mc_module.profile(features)
+        self.assertEqual(mc_module.open_slots().item(), 0)

        _mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
        self.assertEqual(
@@ -112,6 +113,7 @@
        )
        _mch_last_access_iter = mc_module._mch_last_access_iter
        self.assertEqual(list(_mch_last_access_iter), [2, 2, 3, 3, 3])
+        self.assertEqual(mc_module.open_slots().item(), 0)

    def test_distance_lfu_eviction(self) -> None:
        mc_module = MCHManagedCollisionModule(
