diff --git a/torchrec/inference/inference_legacy/src/SingleGPUExecutor.cpp b/torchrec/inference/inference_legacy/src/SingleGPUExecutor.cpp
index 38da5527f..828d88b1a 100644
--- a/torchrec/inference/inference_legacy/src/SingleGPUExecutor.cpp
+++ b/torchrec/inference/inference_legacy/src/SingleGPUExecutor.cpp
@@ -38,7 +38,7 @@ SingleGPUExecutor::SingleGPUExecutor(
   for (size_t i = 0; i < numProcessThreads_; ++i) {
     processExecutor_->add([&]() { process(); });
   }
-  for (const auto& exec_info : execInfos_) {
+  for ([[maybe_unused]] const auto& exec_info : execInfos_) {
     TORCHREC_CHECK(exec_info.interpIdx < manager_->allInstances().size());
   }
   TORCHREC_CHECK(observer_);
diff --git a/torchrec/inference/tests/test_inference.py b/torchrec/inference/tests/test_inference.py
index 4560e1f8b..d0ad0469a 100644
--- a/torchrec/inference/tests/test_inference.py
+++ b/torchrec/inference/tests/test_inference.py
@@ -13,6 +13,7 @@
 import torch

 from fbgemm_gpu.split_embedding_configs import SparseType
+from torchrec import PoolingType
 from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
 from torchrec.distributed.global_settings import set_propogate_device
 from torchrec.distributed.test_utils.test_model import (
@@ -298,3 +299,35 @@ def test_sharded_quantized_tbe_count(self) -> None:
                 spec[1],
                 expected_num_embeddings[spec[0]],
             )
+
+    def test_quantized_tbe_count_different_pooling(self) -> None:
+        set_propogate_device(True)
+
+        self.tables[0].pooling = PoolingType.MEAN
+        model = TestSparseNN(
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            num_float_features=10,
+            dense_device=torch.device("cpu"),
+            sparse_device=torch.device("cpu"),
+            over_arch_clazz=TestOverArchRegroupModule,
+        )
+
+        model.eval()
+        _, local_batch = ModelInput.generate(
+            batch_size=16,
+            world_size=1,
+            num_float_features=10,
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+        )
+
+        model(local_batch[0])
+
+        # Quantize the model and collect quantized weights
+        quantized_model = quantize_inference_model(model)
+        # We should have 2 TBEs for unweighted ebc as the 2 tables here have different pooling types
+        self.assertTrue(len(quantized_model.sparse.ebc.tbes) == 2)
+        self.assertTrue(len(quantized_model.sparse.weighted_ebc.tbes) == 1)
+        # Changing this back
+        self.tables[0].pooling = PoolingType.SUM
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 1fdfae4b4..132d3e18e 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -102,18 +102,25 @@ class EmbeddingBagCollection(EmbeddingBagCollectionInterface):
     For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingBagCollection.

-    It processes sparse data in the form of `KeyedJaggedTensor` with values of the form
-    [F X B X L] where:
+    It is callable on arguments representing sparse data in the form of `KeyedJaggedTensor` with values of shape
+    `(F, B, L_{f,i})` where:

-    * F: features (keys)
-    * B: batch size
-    * L: length of sparse features (jagged)
+    * `F`: number of features (keys)
+    * `B`: batch size
+    * `L_{f,i}`: length of sparse features (potentially distinct for each feature `f` and batch index `i`, that is, jagged)

-    and outputs a `KeyedTensor` with values of the form [B * (F * D)] where:
+    and outputs a `KeyedTensor` with values of shape `(B, D)` where:

-    * F: features (keys)
-    * D: each feature's (key's) embedding dimension
-    * B: batch size
+    * `B`: batch size
+    * `D`: sum of embedding dimensions of all embedding tables, that is, `sum([config.embedding_dim for config in tables])`
+
+    Assuming the argument is a `KeyedJaggedTensor` `J` with `F` features, batch size `B`, and sparse lengths `L_{f,i}`,
+    such that `J[f][i]` is the bag for feature `f` and batch index `i`, the output `KeyedTensor` `KT` is defined as follows:
+    `KT[i] = torch.cat([emb[f](J[f][i]) for f in J.keys()])`, where `emb[f]` is the `EmbeddingBag` corresponding to feature `f`.
+
+    Note that `J[f][i]` is a variable-length list of integer values (a bag), and `emb[f](J[f][i])` is the pooled embedding
+    produced by reducing the embeddings of each of the values in `J[f][i]`
+    using the `EmbeddingBag` `emb[f]`'s mode (default is the mean).

     Args:
         tables (List[EmbeddingBagConfig]): list of embedding tables.
@@ -131,28 +138,34 @@ class EmbeddingBagCollection(EmbeddingBagCollectionInterface):
         ebc = EmbeddingBagCollection(tables=[table_0, table_1])

-        # 0       1        2      <-- batch
-        # "f1"   [0,1]    None    [2]
-        # "f2"   [3]      [4]     [5,6,7]
+        #        i = 0    i = 1   i = 2   <-- batch indices
+        # "f1"   [0,1]    None    [2]
+        # "f2"   [3]      [4]     [5,6,7]
         #  ^
-        # feature
+        # features

         features = KeyedJaggedTensor(
             keys=["f1", "f2"],
-            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
-            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
+            values=torch.tensor([0, 1, 2,  # feature 'f1'
+                                 3, 4, 5, 6, 7]),  # feature 'f2'
+            #                    i = 0    i = 1    i = 2   <--- batch indices
+            offsets=torch.tensor([
+                0, 2, 2,  # 'f1' bags are values[0:2], values[2:2], and values[2:3]
+                3, 4, 5, 8]),  # 'f2' bags are values[3:4], values[4:5], and values[5:8]
         )

         pooled_embeddings = ebc(features)
         print(pooled_embeddings.values())
-        tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783],
-            [ 0.0000,  0.0000,  0.0000,  0.1598,  0.0695,  1.3265, -0.1011],
-            [-0.4256, -1.1846, -2.1648, -1.0893,  0.3590, -1.9784, -0.7681]],
+        tensor([
+            # f1 pooled embeddings from bags (dim 3)    f2 pooled embeddings from bags (dim 4)
+            [-0.8899, -0.1342, -1.9060,    -0.0905, -0.2814, -0.9369, -0.7783],  # batch index 0
+            [ 0.0000,  0.0000,  0.0000,     0.1598,  0.0695,  1.3265, -0.1011],  # batch index 1
+            [-0.4256, -1.1846, -2.1648,    -1.0893,  0.3590, -1.9784, -0.7681]],  # batch index 2
             grad_fn=<CatBackward0>)
         print(pooled_embeddings.keys())
         ['f1', 'f2']
         print(pooled_embeddings.offset_per_key())
-        tensor([0, 3, 7])
+        tensor([0, 3, 7])  # embeddings have dimensions 3 and 4, so embeddings are at [0, 3) and [3, 7).

     """

     def __init__(
diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 8bcbc0bb2..e36b776a7 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -382,15 +382,14 @@ def __init__(
             if table.name in table_names:
                 raise ValueError(f"Duplicate table name {table.name}")
             table_names.add(table.name)
-            key = (table.pooling, table.data_type)
-            self._key_to_tables[key].append(table)
+            # pyre-ignore
+            self._key_to_tables[table.pooling].append(table)

         location = (
             EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
         )

-        for key, emb_configs in self._key_to_tables.items():
-            (pooling, data_type) = key
+        for pooling, emb_configs in self._key_to_tables.items():
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -409,7 +408,7 @@ def __init__(
                         else table.num_embeddings
                     ),
                     table.embedding_dim,
-                    data_type_to_sparse_type(data_type),
+                    data_type_to_sparse_type(table.data_type),
                     location,
                 )
             )
@@ -421,6 +420,7 @@ def __init__(

             emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
                 embedding_specs=embedding_specs,
+                # pyre-ignore
                 pooling_mode=pooling_type_to_pooling_mode(pooling),
                 weight_lists=weight_lists,
                 device=device,