diff --git a/torchrec/modules/mc_embedding_modules.py b/torchrec/modules/mc_embedding_modules.py
index 2ea4b9060..8b66e1e45 100644
--- a/torchrec/modules/mc_embedding_modules.py
+++ b/torchrec/modules/mc_embedding_modules.py
@@ -6,32 +6,34 @@
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
-from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.embedding_modules import (
+    EmbeddingBagCollection,
+    EmbeddingCollection,
+)
 from torchrec.modules.mc_modules import ManagedCollisionCollection
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
 
 
 def evict(
-    evictions: Dict[str, Optional[torch.Tensor]], ebc: EmbeddingBagCollection
+    evictions: Dict[str, Optional[torch.Tensor]],
+    ebc: Union[EmbeddingBagCollection, EmbeddingCollection],
 ) -> None:
     # TODO: write function
     return
 
 
-class ManagedCollisionEmbeddingBagCollection(nn.Module):
+class BaseManagedCollisionEmbeddingCollection(nn.Module):
     """
-    ManagedCollisionEmbeddingBagCollection represents a EmbeddingBagCollection module and a set of managed collision modules.
-    The inputs into the MC-EBC will first be modified by the managed collision module before being passed into the embedding bag collection.
-
-    For details of input and output types, see EmbeddingBagCollection
+    BaseManagedCollisionEmbeddingCollection represents an EC/EBC module and a set of managed collision modules.
+    The inputs into the MC-EC/EBC will first be modified by the managed collision module before being passed into the embedding module.
 
     Args:
-        embedding_bag_collection: EmbeddingBagCollection to lookup embeddings
+        embedding_module: EmbeddingCollection or EmbeddingBagCollection to lookup embeddings
         managed_collision_modules: Dict of managed collision modules
         return_remapped_features (bool): whether to return remapped input features
            in addition to embeddings
@@ -40,33 +42,104 @@ class ManagedCollisionEmbeddingBagCollection(nn.Module):
 
     def __init__(
         self,
-        embedding_bag_collection: EmbeddingBagCollection,
+        embedding_module: Union[EmbeddingBagCollection, EmbeddingCollection],
         managed_collision_collection: ManagedCollisionCollection,
         return_remapped_features: bool = False,
     ) -> None:
         super().__init__()
-        self._embedding_bag_collection = embedding_bag_collection
         self._managed_collision_collection = managed_collision_collection
         self._return_remapped_features = return_remapped_features
-
-        assert (
-            self._embedding_bag_collection.embedding_bag_configs()
-            == self._managed_collision_collection.embedding_configs()
-        ), "Embedding Collection and Managed Collision Collection must contain the Embedding Configs"
+        self._embedding_module: Union[
+            EmbeddingBagCollection, EmbeddingCollection
+        ] = embedding_module
+
+        if isinstance(embedding_module, EmbeddingBagCollection):
+            assert (
+                self._embedding_module.embedding_bag_configs()
+                == self._managed_collision_collection.embedding_configs()
+            ), "Embedding Bag Collection and Managed Collision Collection must contain the same Embedding Configs"
+
+        else:
+            assert (
+                self._embedding_module.embedding_configs()
+                == self._managed_collision_collection.embedding_configs()
+            ), "Embedding Collection and Managed Collision Collection must contain the same Embedding Configs"
 
     def forward(
         self,
         features: KeyedJaggedTensor,
-    ) -> Tuple[KeyedTensor, Optional[KeyedJaggedTensor]]:
+    ) -> Tuple[
+        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
+    ]:
         features = self._managed_collision_collection(features)
 
-        pooled_embeddings = self._embedding_bag_collection(features)
+        embedding_res = self._embedding_module(features)
 
-        evict(
-            self._managed_collision_collection.evict(), self._embedding_bag_collection
-        )
+        evict(self._managed_collision_collection.evict(), self._embedding_module)
 
         if not self._return_remapped_features:
-            return pooled_embeddings, None
-        return pooled_embeddings, features
+            return embedding_res, None
+        return embedding_res, features
+
+
+class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollection):
+    """
+    ManagedCollisionEmbeddingCollection represents an EmbeddingCollection module and a set of managed collision modules.
+    The inputs into the MC-EC will first be modified by the managed collision module before being passed into the embedding collection.
+
+    For details of input and output types, see EmbeddingCollection
+
+    Args:
+        embedding_collection: EmbeddingCollection to lookup embeddings
+        managed_collision_collection: ManagedCollisionCollection to remap input features
+        return_remapped_features (bool): whether to return remapped input features
+            in addition to embeddings
+
+    """
+
+    def __init__(
+        self,
+        embedding_collection: EmbeddingCollection,
+        managed_collision_collection: ManagedCollisionCollection,
+        return_remapped_features: bool = False,
+    ) -> None:
+        super().__init__(
+            embedding_collection, managed_collision_collection, return_remapped_features
+        )
+
+        # For consistency with embedding bag collection
+        self._embedding_collection: EmbeddingCollection = embedding_collection
+
+
+class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingCollection):
+    """
+    ManagedCollisionEmbeddingBagCollection represents an EmbeddingBagCollection module and a set of managed collision modules.
+    The inputs into the MC-EBC will first be modified by the managed collision module before being passed into the embedding bag collection.
+
+    For details of input and output types, see EmbeddingBagCollection
+
+    Args:
+        embedding_bag_collection: EmbeddingBagCollection to lookup embeddings
+        managed_collision_collection: ManagedCollisionCollection to remap input features
+        return_remapped_features (bool): whether to return remapped input features
+            in addition to embeddings
+
+    """
+
+    def __init__(
+        self,
+        embedding_bag_collection: EmbeddingBagCollection,
+        managed_collision_collection: ManagedCollisionCollection,
+        return_remapped_features: bool = False,
+    ) -> None:
+        super().__init__(
+            embedding_bag_collection,
+            managed_collision_collection,
+            return_remapped_features,
+        )
+
+        # For backwards compat, as references existed in tests
+        self._embedding_bag_collection: EmbeddingBagCollection = (
+            embedding_bag_collection
+        )
diff --git a/torchrec/modules/tests/test_mc_embedding_modules.py b/torchrec/modules/tests/test_mc_embedding_modules.py
index a566b7649..ad372bac3 100644
--- a/torchrec/modules/tests/test_mc_embedding_modules.py
+++ b/torchrec/modules/tests/test_mc_embedding_modules.py
@@ -6,19 +6,26 @@
 # LICENSE file in the root directory of this source tree.
 import unittest
-from typing import cast, List, Optional
+from copy import deepcopy
+from typing import cast, Dict, List, Optional
 
 import torch
 
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
-from torchrec.modules.embedding_modules import EmbeddingBagCollection
-from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
+from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
+from torchrec.modules.embedding_modules import (
+    EmbeddingBagCollection,
+    EmbeddingCollection,
+)
+from torchrec.modules.mc_embedding_modules import (
+    ManagedCollisionEmbeddingBagCollection,
+    ManagedCollisionEmbeddingCollection,
+)
 from torchrec.modules.mc_modules import (
     DistanceLFU_EvictionPolicy,
     ManagedCollisionCollection,
     ManagedCollisionModule,
     MCHManagedCollisionModule,
 )
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 
 
 class Tracer(torch.fx.Tracer):
@@ -38,13 +45,13 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool
 
 
 class MCHManagedCollisionEmbeddingBagCollectionTest(unittest.TestCase):
-    def test_zch_ebc_train(self) -> None:
+    def test_zch_ebc_ec_train(self) -> None:
         device = torch.device("cpu")
         zch_size = 20
         update_interval = 2
         update_size = 10
 
-        embedding_configs = [
+        embedding_bag_configs = [
             EmbeddingBagConfig(
                 name="t1",
                 embedding_dim=8,
@@ -52,10 +59,24 @@ def test_zch_ebc_train(self) -> None:
                 feature_names=["f1", "f2"],
             ),
         ]
+        embedding_configs = [
+            EmbeddingConfig(
+                name="t1",
+                embedding_dim=8,
+                num_embeddings=zch_size,
+                feature_names=["f1", "f2"],
+            ),
+        ]
         ebc = EmbeddingBagCollection(
+            tables=embedding_bag_configs,
+            device=device,
+        )
+
+        ec = EmbeddingCollection(
             tables=embedding_configs,
             device=device,
         )
+
         mc_modules = {
             "t1": cast(
                 ManagedCollisionModule,
@@ -67,16 +88,29 @@ def test_zch_ebc_train(self) -> None:
                ),
            ),
        }
-        mcc = ManagedCollisionCollection(
+        mcc_ebc = ManagedCollisionCollection(
             managed_collision_modules=mc_modules,
             # pyre-ignore[6]
+            embedding_configs=embedding_bag_configs,
+        )
+
+        mcc_ec = ManagedCollisionCollection(
+            managed_collision_modules=deepcopy(mc_modules),
+            # pyre-ignore[6]
             embedding_configs=embedding_configs,
         )
 
         mc_ebc = ManagedCollisionEmbeddingBagCollection(
             ebc,
-            mcc,
+            mcc_ebc,
             return_remapped_features=True,
         )
+        mc_ec = ManagedCollisionEmbeddingCollection(
+            ec,
+            mcc_ec,
+            return_remapped_features=True,
+        )
+
+        mc_modules = [mc_ebc, mc_ec]
 
         update_one = KeyedJaggedTensor.from_lengths_sync(
             keys=["f1", "f2"],
@@ -93,79 +127,109 @@
             lengths=torch.ones((2 * update_size,), dtype=torch.int64),
             weights=None,
         )
-        _, remapped_kjt1 = mc_ebc.forward(update_one)
-        _, remapped_kjt2 = mc_ebc.forward(update_one)
-        assert torch.all(
-            # pyre-ignore[16]
-            remapped_kjt1["f1"].values()
-            == zch_size - 1
-        ), "all remapped ids should be mapped to end of range"
-        assert torch.all(
-            remapped_kjt1["f2"].values() == zch_size - 1
-        ), "all remapped ids should be mapped to end of range"
-
-        assert torch.all(
-            remapped_kjt2["f1"].values() == torch.arange(0, 10, dtype=torch.int64)
-        )
-        assert torch.all(
-            remapped_kjt2["f2"].values()
-            == torch.cat(
-                [
-                    torch.arange(10, 19, dtype=torch.int64),
-                    torch.tensor([zch_size - 1], dtype=torch.int64),  # empty value
-                ]
+
+        for mc_module in mc_modules:
+            out1, remapped_kjt1 = mc_module.forward(update_one)
+            out2, remapped_kjt2 = mc_module.forward(update_one)
+
+            assert torch.all(
+                # pyre-ignore[16]
+                remapped_kjt1["f1"].values()
+                == zch_size - 1
+            ), "all remapped ids should be mapped to end of range"
+            assert torch.all(
+                remapped_kjt1["f2"].values() == zch_size - 1
+            ), "all remapped ids should be mapped to end of range"
+
+            assert torch.all(
+                remapped_kjt2["f1"].values() == torch.arange(0, 10, dtype=torch.int64)
             )
-        update_two = KeyedJaggedTensor.from_lengths_sync(
-            keys=["f1", "f2"],
-            values=torch.concat(
-                [
-                    torch.arange(2000, 2000 + update_size, dtype=torch.int64),
-                    torch.arange(
-                        1000 + update_size,
-                        1000 + 2 * update_size,
-                        dtype=torch.int64,
-                    ),
-                ]
-            ),
-            lengths=torch.ones((2 * update_size,), dtype=torch.int64),
-            weights=None,
-        )
-        _, remapped_kjt3 = mc_ebc.forward(update_two)
-        _, remapped_kjt4 = mc_ebc.forward(update_two)
+            assert torch.all(
+                remapped_kjt2["f2"].values()
+                == torch.cat(
+                    [
+                        torch.arange(10, 19, dtype=torch.int64),
+                        torch.tensor([zch_size - 1], dtype=torch.int64),  # empty value
+                    ]
+                )
+            )
+
+            if isinstance(mc_module, ManagedCollisionEmbeddingCollection):
+                self.assertTrue(isinstance(out1, Dict))
+                self.assertTrue(isinstance(out2, Dict))
+                self.assertEqual(out1["f1"].values().size(), (update_size, 8))
+                self.assertEqual(out2["f2"].values().size(), (update_size, 8))
+            else:
+                self.assertTrue(isinstance(out1, KeyedTensor))
+                self.assertTrue(isinstance(out2, KeyedTensor))
+                self.assertEqual(out1["f1"].size(), (update_size, 8))
+                self.assertEqual(out2["f2"].size(), (update_size, 8))
+
+            update_two = KeyedJaggedTensor.from_lengths_sync(
+                keys=["f1", "f2"],
+                values=torch.concat(
+                    [
+                        torch.arange(2000, 2000 + update_size, dtype=torch.int64),
+                        torch.arange(
+                            1000 + update_size,
+                            1000 + 2 * update_size,
+                            dtype=torch.int64,
+                        ),
+                    ]
+                ),
+                lengths=torch.ones((2 * update_size,), dtype=torch.int64),
+                weights=None,
+            )
+            out3, remapped_kjt3 = mc_module.forward(update_two)
+            out4, remapped_kjt4 = mc_module.forward(update_two)
 
-        assert torch.all(
-            remapped_kjt3["f1"].values() == zch_size - 1
-        ), "all remapped ids should be mapped to end of range"
+            assert torch.all(
+                remapped_kjt3["f1"].values() == zch_size - 1
+            ), "all remapped ids should be mapped to end of range"
 
-        assert torch.all(remapped_kjt3["f2"].values() == remapped_kjt2["f2"].values())
+            assert torch.all(
+                remapped_kjt3["f2"].values() == remapped_kjt2["f2"].values()
+            )
 
-        assert torch.all(
-            remapped_kjt4["f1"].values()
-            == torch.cat(
-                [
-                    torch.arange(1, 10, dtype=torch.int64),
-                    torch.tensor([zch_size - 1], dtype=torch.int64),  # empty value
-                ]
+            assert torch.all(
+                remapped_kjt4["f1"].values()
+                == torch.cat(
+                    [
+                        torch.arange(1, 10, dtype=torch.int64),
+                        torch.tensor([zch_size - 1], dtype=torch.int64),  # empty value
+                    ]
+                )
             )
-        )
 
-        assert torch.all(
-            remapped_kjt4["f2"].values()
-            == torch.cat(
-                [
-                    torch.arange(10, 19, dtype=torch.int64),
-                    torch.tensor([0], dtype=torch.int64),  # assigned first open slot
-                ]
+            assert torch.all(
+                remapped_kjt4["f2"].values()
+                == torch.cat(
+                    [
+                        torch.arange(10, 19, dtype=torch.int64),
+                        torch.tensor(
+                            [0], dtype=torch.int64
+                        ),  # assigned first open slot
+                    ]
+                )
             )
-        )
 
-    def test_zch_ebc_eval(self) -> None:
+            if isinstance(mc_module, ManagedCollisionEmbeddingCollection):
+                self.assertTrue(isinstance(out3, Dict))
+                self.assertTrue(isinstance(out4, Dict))
+                self.assertEqual(out3["f1"].values().size(), (update_size, 8))
+                self.assertEqual(out4["f2"].values().size(), (update_size, 8))
+            else:
+                self.assertTrue(isinstance(out3, KeyedTensor))
+                self.assertTrue(isinstance(out4, KeyedTensor))
+                self.assertEqual(out3["f1"].size(), (update_size, 8))
+                self.assertEqual(out4["f2"].size(), (update_size, 8))
+
+    def test_zch_ebc_ec_eval(self) -> None:
         device = torch.device("cpu")
         zch_size = 20
         update_interval = 2
         update_size = 10
 
-        embedding_configs = [
+        embedding_bag_configs = [
             EmbeddingBagConfig(
                 name="t1",
                 embedding_dim=8,
@@ -173,7 +237,19 @@ def test_zch_ebc_eval(self) -> None:
                 feature_names=["f1", "f2"],
             ),
         ]
+        embedding_configs = [
+            EmbeddingConfig(
+                name="t1",
+                embedding_dim=8,
+                num_embeddings=zch_size,
+                feature_names=["f1", "f2"],
+            ),
+        ]
         ebc = EmbeddingBagCollection(
+            tables=embedding_bag_configs,
+            device=device,
+        )
+        ec = EmbeddingCollection(
             tables=embedding_configs,
             device=device,
         )
@@ -188,89 +264,107 @@ def test_zch_ebc_eval(self) -> None:
                ),
            ),
        }
-        mcc = ManagedCollisionCollection(
+        mcc_ebc = ManagedCollisionCollection(
             managed_collision_modules=mc_modules,
             # pyre-ignore[6]
+            embedding_configs=embedding_bag_configs,
+        )
+
+        mcc_ec = ManagedCollisionCollection(
+            managed_collision_modules=deepcopy(mc_modules),
+            # pyre-ignore[6]
             embedding_configs=embedding_configs,
         )
 
         mc_ebc = ManagedCollisionEmbeddingBagCollection(
             ebc,
-            mcc,
+            mcc_ebc,
             return_remapped_features=True,
         )
-
-        update_one = KeyedJaggedTensor.from_lengths_sync(
-            keys=["f1", "f2"],
-            values=torch.concat(
-                [
-                    torch.arange(1000, 1000 + update_size, dtype=torch.int64),
-                    torch.arange(
-                        1000 + update_size,
-                        1000 + 2 * update_size,
-                        dtype=torch.int64,
-                    ),
-                ]
-            ),
-            lengths=torch.ones((2 * update_size,), dtype=torch.int64),
-            weights=None,
-        )
-        _, remapped_kjt1 = mc_ebc.forward(update_one)
-        _, remapped_kjt2 = mc_ebc.forward(update_one)
-
-        assert torch.all(
-            # pyre-ignore[16]
-            remapped_kjt1["f1"].values()
-            == zch_size - 1
-        ), "all remapped ids should be mapped to end of range"
-        assert torch.all(
-            remapped_kjt1["f2"].values() == zch_size - 1
-        ), "all remapped ids should be mapped to end of range"
-
-        assert torch.all(
-            remapped_kjt2["f1"].values() == torch.arange(0, 10, dtype=torch.int64)
-        )
-        assert torch.all(
-            remapped_kjt2["f2"].values()
-            == torch.cat(
-                [
-                    torch.arange(10, 19, dtype=torch.int64),
-                    torch.tensor([zch_size - 1], dtype=torch.int64),  # empty value
-                ]
-            )
+        mc_ec = ManagedCollisionEmbeddingCollection(
+            ec,
+            mcc_ec,
+            return_remapped_features=True,
         )
 
-        # Trigger eval mode, zch should not update
-        mc_ebc.eval()
+        mc_modules = [mc_ebc, mc_ec]
+
+        for mc_module in mc_modules:
+            update_one = KeyedJaggedTensor.from_lengths_sync(
+                keys=["f1", "f2"],
+                values=torch.concat(
+                    [
+                        torch.arange(1000, 1000 + update_size, dtype=torch.int64),
+                        torch.arange(
+                            1000 + update_size,
+                            1000 + 2 * update_size,
+                            dtype=torch.int64,
+                        ),
+                    ]
+                ),
+                lengths=torch.ones((2 * update_size,), dtype=torch.int64),
+                weights=None,
+            )
+            _, remapped_kjt1 = mc_module.forward(update_one)
+            _, remapped_kjt2 = mc_module.forward(update_one)
+
+            assert torch.all(
+                # pyre-ignore[16]
+                remapped_kjt1["f1"].values()
+                == zch_size - 1
+            ), "all remapped ids should be mapped to end of range"
+            assert torch.all(
+                remapped_kjt1["f2"].values() == zch_size - 1
+            ), "all remapped ids should be mapped to end of range"
+
+            assert torch.all(
+                remapped_kjt2["f1"].values() == torch.arange(0, 10, dtype=torch.int64)
+            )
+            assert torch.all(
+                remapped_kjt2["f2"].values()
+                == torch.cat(
+                    [
+                        torch.arange(10, 19, dtype=torch.int64),
+                        torch.tensor([zch_size - 1], dtype=torch.int64),  # empty value
+                    ]
+                )
+            )
 
-        update_two = KeyedJaggedTensor.from_lengths_sync(
-            keys=["f1", "f2"],
-            values=torch.concat(
-                [
-                    torch.arange(2000, 2000 + update_size, dtype=torch.int64),
-                    torch.arange(
-                        1000 + update_size,
-                        1000 + 2 * update_size,
-                        dtype=torch.int64,
-                    ),
-                ]
-            ),
-            lengths=torch.ones((2 * update_size,), dtype=torch.int64),
-            weights=None,
-        )
-        _, remapped_kjt3 = mc_ebc.forward(update_two)
-        _, remapped_kjt4 = mc_ebc.forward(update_two)
+            # Trigger eval mode, zch should not update
+            mc_module.eval()
+
+            update_two = KeyedJaggedTensor.from_lengths_sync(
+                keys=["f1", "f2"],
+                values=torch.concat(
+                    [
+                        torch.arange(2000, 2000 + update_size, dtype=torch.int64),
+                        torch.arange(
+                            1000 + update_size,
+                            1000 + 2 * update_size,
+                            dtype=torch.int64,
+                        ),
+                    ]
+                ),
+                lengths=torch.ones((2 * update_size,), dtype=torch.int64),
+                weights=None,
+            )
+            _, remapped_kjt3 = mc_module.forward(update_two)
+            _, remapped_kjt4 = mc_module.forward(update_two)
 
-        assert torch.all(
-            remapped_kjt3["f1"].values() == zch_size - 1
-        ), "all remapped ids should be mapped to end of range"
+            assert torch.all(
+                remapped_kjt3["f1"].values() == zch_size - 1
+            ), "all remapped ids should be mapped to end of range"
 
-        assert torch.all(remapped_kjt3["f2"].values() == remapped_kjt2["f2"].values())
+            assert torch.all(
+                remapped_kjt3["f2"].values() == remapped_kjt2["f2"].values()
+            )
 
-        assert torch.all(
-            remapped_kjt4["f1"].values() == zch_size - 1
-        ), "all remapped ids should be mapped to end of range"
+            assert torch.all(
+                remapped_kjt4["f1"].values() == zch_size - 1
+            ), "all remapped ids should be mapped to end of range"
 
-        assert torch.all(remapped_kjt4["f2"].values() == remapped_kjt2["f2"].values())
+            assert torch.all(
+                remapped_kjt4["f2"].values() == remapped_kjt2["f2"].values()
+            )
 
     def test_mc_collection_traceable(self) -> None:
         device = torch.device("cpu")
@@ -315,14 +409,14 @@ def test_mc_collection_traceable(self) -> None:
 
         # TODO: since this is unsharded module, also check torch.jit.script
 
-    def test_mch_ebc(self) -> None:
+    def test_mch_ebc_ec(self) -> None:
         device = torch.device("cpu")
         zch_size = 10
         mch_size = 10
         update_interval = 2
         update_size = 10
 
-        embedding_configs = [
+        embedding_bag_configs = [
             EmbeddingBagConfig(
                 name="t1",
                 embedding_dim=8,
@@ -330,7 +424,20 @@ def test_mch_ebc(self) -> None:
                 feature_names=["f1", "f2"],
             ),
         ]
+        embedding_configs = [
+            EmbeddingConfig(
+                name="t1",
+                embedding_dim=8,
+                num_embeddings=zch_size + mch_size,
+                feature_names=["f1", "f2"],
+            ),
+        ]
+        ebc = EmbeddingBagCollection(
+            tables=embedding_bag_configs,
+            device=device,
+        )
+        ec = EmbeddingCollection(
             tables=embedding_configs,
             device=device,
         )
@@ -351,100 +458,114 @@ def preprocess_func(id: torch.Tensor, hash_size: int) -> torch.Tensor:
                ),
            ),
        }
-        mcc = ManagedCollisionCollection(
+        mcc_ec = ManagedCollisionCollection(
             managed_collision_modules=mc_modules,
             # pyre-ignore[6]
             embedding_configs=embedding_configs,
         )
+        mcc_ebc = ManagedCollisionCollection(
+            managed_collision_modules=deepcopy(mc_modules),
+            # pyre-ignore[6]
+            embedding_configs=embedding_bag_configs,
+        )
 
         mc_ebc = ManagedCollisionEmbeddingBagCollection(
             ebc,
-            mcc,
+            mcc_ebc,
             return_remapped_features=True,
         )
-
-        update_one = KeyedJaggedTensor.from_lengths_sync(
-            keys=["f1", "f2"],
-            values=torch.concat(
-                [
-                    torch.arange(1000, 1000 + update_size, dtype=torch.int64),
-                    torch.arange(
-                        1000 + update_size,
-                        1000 + 2 * update_size,
-                        dtype=torch.int64,
-                    ),
-                ]
-            ),
-            lengths=torch.ones((2 * update_size,), dtype=torch.int64),
-            weights=None,
+        mc_ec = ManagedCollisionEmbeddingCollection(
+            ec,
+            mcc_ec,
+            return_remapped_features=True,
         )
-        _, remapped_kjt1 = mc_ebc.forward(update_one)
-        _, remapped_kjt2 = mc_ebc.forward(update_one)
-
-        assert torch.all(
-            # pyre-ignore[16]
-            remapped_kjt1["f1"].values()
-            == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64)
-        ), "all remapped ids are in mch section"
-        assert torch.all(
-            remapped_kjt1["f2"].values()
-            == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64)
-        ), "all remapped ids are in mch section"
-
-        assert torch.all(
-            remapped_kjt2["f1"].values()
-            == torch.cat(
-                [
-                    torch.arange(0, 9, dtype=torch.int64),
-                    torch.tensor([19], dtype=torch.int64),  # % MCH for last value
-                ]
+        mc_modules = [mc_ebc, mc_ec]
+
+        for mc_module in mc_modules:
+            update_one = KeyedJaggedTensor.from_lengths_sync(
+                keys=["f1", "f2"],
+                values=torch.concat(
+                    [
+                        torch.arange(1000, 1000 + update_size, dtype=torch.int64),
+                        torch.arange(
+                            1000 + update_size,
+                            1000 + 2 * update_size,
+                            dtype=torch.int64,
+                        ),
+                    ]
+                ),
+                lengths=torch.ones((2 * update_size,), dtype=torch.int64),
+                weights=None,
+            )
+            _, remapped_kjt1 = mc_module.forward(update_one)
+            _, remapped_kjt2 = mc_module.forward(update_one)
+
+            assert torch.all(
+                # pyre-ignore[16]
+                remapped_kjt1["f1"].values()
+                == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64)
+            ), "all remapped ids are in mch section"
+            assert torch.all(
+                remapped_kjt1["f2"].values()
+                == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64)
+            ), "all remapped ids are in mch section"
+
+            assert torch.all(
+                remapped_kjt2["f1"].values()
+                == torch.cat(
+                    [
+                        torch.arange(0, 9, dtype=torch.int64),
+                        torch.tensor([19], dtype=torch.int64),  # % MCH for last value
+                    ]
+                )
             )
-        )
-
-        assert torch.all(
-            remapped_kjt2["f2"].values()
-            == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64)
-        ), "all remapped ids are in mch section"
-
-        update_two = KeyedJaggedTensor.from_lengths_sync(
-            keys=["f1", "f2"],
-            values=torch.concat(
-                [
-                    torch.arange(2000, 2000 + update_size, dtype=torch.int64),
-                    torch.arange(
-                        1000 + update_size,
-                        1000 + 2 * update_size,
-                        dtype=torch.int64,
-                    ),
-                ]
-            ),
-            lengths=torch.ones((2 * update_size,), dtype=torch.int64),
-            weights=None,
-        )
-
-        _, remapped_kjt3 = mc_ebc.forward(update_two)
-        _, remapped_kjt4 = mc_ebc.forward(update_two)
-
-        assert torch.all(
-            remapped_kjt3["f1"].values()
-            == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64)
-        ), "all remapped ids are in mch section"
-
-        assert torch.all(
-            remapped_kjt3["f2"].values()
-            == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64)
-        ), "all remapped ids are in mch section"
-        assert torch.all(
-            remapped_kjt4["f1"].values()
-            == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64)
-        ), "all remapped ids are in mch section"
+            assert torch.all(
+                remapped_kjt2["f2"].values()
+                == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64)
+            ), "all remapped ids are in mch section"
+
+            update_two = KeyedJaggedTensor.from_lengths_sync(
+                keys=["f1", "f2"],
+                values=torch.concat(
+                    [
+                        torch.arange(2000, 2000 + update_size, dtype=torch.int64),
+                        torch.arange(
+                            1000 + update_size,
+                            1000 + 2 * update_size,
+                            dtype=torch.int64,
+                        ),
+                    ]
+                ),
+                lengths=torch.ones((2 * update_size,), dtype=torch.int64),
+                weights=None,
+            )
 
-        assert torch.all(
-            remapped_kjt4["f2"].values()
-            == torch.cat(
-                [
-                    torch.arange(0, 9, dtype=torch.int64),
-                    torch.tensor([19], dtype=torch.int64),  # assigned first open slot
-                ]
+            _, remapped_kjt3 = mc_module.forward(update_two)
+            _, remapped_kjt4 = mc_module.forward(update_two)
+
+            assert torch.all(
+                remapped_kjt3["f1"].values()
+                == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64)
+            ), "all remapped ids are in mch section"
+
+            assert torch.all(
+                remapped_kjt3["f2"].values()
+                == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64)
+            ), "all remapped ids are in mch section"
+
+            assert torch.all(
+                remapped_kjt4["f1"].values()
+                == torch.arange(zch_size, zch_size + mch_size, dtype=torch.int64)
+            ), "all remapped ids are in mch section"
+
+            assert torch.all(
+                remapped_kjt4["f2"].values()
+                == torch.cat(
+                    [
+                        torch.arange(0, 9, dtype=torch.int64),
+                        torch.tensor(
+                            [19], dtype=torch.int64
+                        ),  # assigned first open slot
+                    ]
+                )
+            )
-        )
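
Reviewer note: below is a minimal usage sketch of the new ManagedCollisionEmbeddingCollection path, distilled from the tests above. The MCHManagedCollisionModule constructor call is elided in the hunks, so the keyword arguments shown for it (zch_size, device, eviction_interval, eviction_policy) are assumptions for illustration, not a confirmed signature.

    import torch

    from torchrec.modules.embedding_configs import EmbeddingConfig
    from torchrec.modules.embedding_modules import EmbeddingCollection
    from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
    from torchrec.modules.mc_modules import (
        DistanceLFU_EvictionPolicy,
        ManagedCollisionCollection,
        MCHManagedCollisionModule,
    )
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    device = torch.device("cpu")
    zch_size = 20

    embedding_configs = [
        EmbeddingConfig(
            name="t1",
            embedding_dim=8,
            num_embeddings=zch_size,
            feature_names=["f1", "f2"],
        ),
    ]
    ec = EmbeddingCollection(tables=embedding_configs, device=device)

    mcc = ManagedCollisionCollection(
        managed_collision_modules={
            # NOTE: kwargs below are assumed; the diff does not show this call
            "t1": MCHManagedCollisionModule(
                zch_size=zch_size,
                device=device,
                eviction_interval=2,
                eviction_policy=DistanceLFU_EvictionPolicy(),
            ),
        },
        embedding_configs=embedding_configs,
    )

    # Same managed-collision remapping as the EBC wrapper, EC-flavored output
    mc_ec = ManagedCollisionEmbeddingCollection(ec, mcc, return_remapped_features=True)

    # 20 single-hot ids (10 per feature), as in the tests above
    features = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1", "f2"],
        values=torch.arange(1000, 1020, dtype=torch.int64),
        lengths=torch.ones((20,), dtype=torch.int64),
        weights=None,
    )

    # embeddings: Dict[str, JaggedTensor] for the EC variant (a KeyedTensor for
    # the EBC variant); remapped: the post-collision KeyedJaggedTensor.
    embeddings, remapped = mc_ec(features)

The only behavioral difference between the two wrappers is the embedding module they delegate to; the managed collision step and the eviction hook are shared via BaseManagedCollisionEmbeddingCollection.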