Skip to content

Commit ddcfd64

Browse files
dstaay-fbfacebook-github-bot
authored andcommitted
Open Slots API (#2249)
Summary: Pull Request resolved: #2249 Adds concept of open slots to show users if just growing table or actual replacement of ids. Also fixes default input_hash_size to max int64 (2**63 - 1) Example logs when insert only: {F1774359971} vs replacement: {F1774361026} Reviewed By: iamzainhuda Differential Revision: D59931393 fbshipit-source-id: 3d46198f5e4d2bedbeaee80886b64f8e4b1817f1
1 parent c89e9df commit ddcfd64

File tree

7 files changed

+81
-9
lines changed

7 files changed

+81
-9
lines changed

torchrec/distributed/mc_embedding.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
#!/usr/bin/env python3
1111

12-
from typing import Dict, List, Optional, Type
12+
from typing import Any, Dict, List, Optional, Type
1313

1414
import torch
1515

@@ -104,11 +104,15 @@ def __init__(
104104
self,
105105
ec_sharder: Optional[EmbeddingCollectionSharder] = None,
106106
mc_sharder: Optional[ManagedCollisionCollectionSharder] = None,
107+
fused_params: Optional[Dict[str, Any]] = None,
107108
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
108109
) -> None:
109110
super().__init__(
110111
ec_sharder
111-
or EmbeddingCollectionSharder(qcomm_codecs_registry=qcomm_codecs_registry),
112+
or EmbeddingCollectionSharder(
113+
qcomm_codecs_registry=qcomm_codecs_registry,
114+
fused_params=fused_params,
115+
),
112116
mc_sharder or ManagedCollisionCollectionSharder(),
113117
qcomm_codecs_registry=qcomm_codecs_registry,
114118
)

torchrec/distributed/mc_embedding_modules.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def input_dist(
152152
)
153153

154154
def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None:
155+
open_slots = None
155156
for table, evictions_indices_for_table in evictions_per_table.items():
156157
if evictions_indices_for_table is not None:
157158
(tbe, logical_table_ids) = self._table_to_tbe_and_index[table]
@@ -160,8 +161,10 @@ def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None
160161
dtype=torch.long,
161162
device=self._device,
162163
)
164+
if open_slots is None:
165+
open_slots = self._managed_collision_collection.open_slots()
163166
logger.info(
164-
f"Evicting {evictions_indices_for_table.numel()} ids from {table}"
167+
f"Table {table}: inserting {evictions_indices_for_table.numel()} ids with {open_slots[table].item()} open slots"
165168
)
166169
with torch.no_grad():
167170
# embeddings, and optimizer state will be reset

torchrec/distributed/mc_embeddingbag.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#!/usr/bin/env python3
1010

1111
from dataclasses import dataclass
12-
from typing import Dict, Optional, Type
12+
from typing import Any, Dict, Optional, Type
1313

1414
import torch
1515
from torchrec.distributed.embedding_types import KJTList
@@ -93,12 +93,13 @@ def __init__(
9393
self,
9494
ebc_sharder: Optional[EmbeddingBagCollectionSharder] = None,
9595
mc_sharder: Optional[ManagedCollisionCollectionSharder] = None,
96+
fused_params: Optional[Dict[str, Any]] = None,
9697
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
9798
) -> None:
9899
super().__init__(
99100
ebc_sharder
100101
or EmbeddingBagCollectionSharder(
101-
qcomm_codecs_registry=qcomm_codecs_registry
102+
fused_params=fused_params, qcomm_codecs_registry=qcomm_codecs_registry
102103
),
103104
mc_sharder or ManagedCollisionCollectionSharder(),
104105
qcomm_codecs_registry=qcomm_codecs_registry,

torchrec/distributed/mc_modules.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ def __init__(
144144
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
145145
) -> None:
146146
super().__init__()
147-
148147
self._device = device
149148
self._env = env
150149
self._table_name_to_parameter_sharding: Dict[str, ParameterSharding] = (
@@ -482,6 +481,15 @@ def evict(self) -> Dict[str, Optional[torch.Tensor]]:
482481
evictions[table] = managed_collision_module.evict()
483482
return evictions
484483

484+
def open_slots(self) -> Dict[str, torch.Tensor]:
485+
open_slots: Dict[str, torch.Tensor] = {}
486+
for (
487+
table,
488+
managed_collision_module,
489+
) in self._managed_collision_modules.items():
490+
open_slots[table] = managed_collision_module.open_slots()
491+
return open_slots
492+
485493
def output_dist(
486494
self,
487495
ctx: ManagedCollisionCollectionContext,

torchrec/modules/mc_modules.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,13 @@ def input_size(self) -> int:
174174
"""
175175
pass
176176

177+
@abc.abstractmethod
178+
def open_slots(self) -> torch.Tensor:
179+
"""
180+
Returns number of unused slots in managed collision module
181+
"""
182+
pass
183+
177184
@abc.abstractmethod
178185
def rebuild_with_output_id_range(
179186
self,
@@ -265,6 +272,15 @@ def evict(self) -> Dict[str, Optional[torch.Tensor]]:
265272
evictions[table] = managed_collision_module.evict()
266273
return evictions
267274

275+
def open_slots(self) -> Dict[str, torch.Tensor]:
276+
open_slots: Dict[str, torch.Tensor] = {}
277+
for (
278+
table,
279+
managed_collision_module,
280+
) in self._managed_collision_modules.items():
281+
open_slots[table] = managed_collision_module.open_slots()
282+
return open_slots
283+
268284

269285
class MCHEvictionPolicyMetadataInfo(NamedTuple):
270286
metadata_name: str
@@ -516,6 +532,7 @@ def coalesce_history_metadata(
516532
additional_ids: Optional[torch.Tensor] = None,
517533
threshold_mask: Optional[torch.Tensor] = None,
518534
) -> Dict[str, torch.Tensor]:
535+
519536
coalesced_history_metadata: Dict[str, torch.Tensor] = {}
520537
history_last_access_iter = history_metadata["last_access_iter"]
521538
if additional_ids is not None:
@@ -792,7 +809,7 @@ def __init__(
792809
device: torch.device,
793810
eviction_policy: MCHEvictionPolicy,
794811
eviction_interval: int,
795-
input_hash_size: int = 2**63,
812+
input_hash_size: int = (2**63) - 1,
796813
input_hash_func: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = None,
797814
mch_size: Optional[int] = None, # experimental
798815
mch_hash_func: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = None,
@@ -845,6 +862,22 @@ def _init_buffers(self) -> None:
845862
device=self.device,
846863
),
847864
)
865+
self.register_buffer(
866+
"_zch_slots",
867+
torch.tensor(
868+
[(self._zch_size - 1)],
869+
dtype=torch.int64,
870+
device=self.device,
871+
),
872+
persistent=False,
873+
)
874+
self.register_buffer(
875+
"_delimiter",
876+
torch.tensor(
877+
[torch.iinfo(torch.int64).max], dtype=torch.int64, device=self.device
878+
),
879+
persistent=False,
880+
)
848881
self.register_buffer(
849882
"_mch_remapped_ids_mapping",
850883
torch.arange(self._zch_size, dtype=torch.int64, device=self.device),
@@ -1140,6 +1173,11 @@ def output_size(self) -> int:
11401173
def input_size(self) -> int:
11411174
return self._input_hash_size
11421175

1176+
def open_slots(self) -> torch.Tensor:
1177+
return self._zch_slots - torch.searchsorted(
1178+
self._mch_sorted_raw_ids, self._delimiter
1179+
)
1180+
11431181
@torch.no_grad()
11441182
def evict(self) -> Optional[torch.Tensor]:
11451183
if self._evicted:

torchrec/modules/tests/test_mc_embedding_modules.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,25 @@ def test_zch_ebc_ec_train(self) -> None:
131131
)
132132

133133
for mc_module in mc_modules:
134+
135+
self.assertEqual(
136+
mc_module._managed_collision_collection.open_slots()["t1"].item(),
137+
zch_size - 1,
138+
) # (ZCH-1 slots)
139+
134140
out1, remapped_kjt1 = mc_module.forward(update_one)
141+
142+
self.assertEqual(
143+
mc_module._managed_collision_collection.open_slots()["t1"].item(),
144+
zch_size - 1,
145+
) # prior update, ZCH-1 slots
146+
135147
out2, remapped_kjt2 = mc_module.forward(update_one)
136148

149+
self.assertEqual(
150+
mc_module._managed_collision_collection.open_slots()["t1"].item(), 0
151+
) # post update, 0 slots
152+
137153
assert torch.all(
138154
# pyre-ignore[16]
139155
remapped_kjt1["f1"].values()

torchrec/modules/tests/test_mc_modules.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_lru_eviction(self) -> None:
8686
)
8787
}
8888
mc_module.profile(features)
89-
89+
self.assertEqual(mc_module.open_slots().item(), 1)
9090
ids = [3, 4, 5]
9191
features: Dict[str, JaggedTensor] = {
9292
"f1": JaggedTensor(
@@ -95,7 +95,7 @@ def test_lru_eviction(self) -> None:
9595
)
9696
}
9797
mc_module.profile(features)
98-
98+
self.assertEqual(mc_module.open_slots().item(), 0)
9999
ids = [7, 8]
100100
features: Dict[str, JaggedTensor] = {
101101
"f1": JaggedTensor(
@@ -104,6 +104,7 @@ def test_lru_eviction(self) -> None:
104104
)
105105
}
106106
mc_module.profile(features)
107+
self.assertEqual(mc_module.open_slots().item(), 0)
107108

108109
_mch_sorted_raw_ids = mc_module._mch_sorted_raw_ids
109110
self.assertEqual(
@@ -112,6 +113,7 @@ def test_lru_eviction(self) -> None:
112113
)
113114
_mch_last_access_iter = mc_module._mch_last_access_iter
114115
self.assertEqual(list(_mch_last_access_iter), [2, 2, 3, 3, 3])
116+
self.assertEqual(mc_module.open_slots().item(), 0)
115117

116118
def test_distance_lfu_eviction(self) -> None:
117119
mc_module = MCHManagedCollisionModule(

0 commit comments

Comments
 (0)