2024-11-20 nightly release (c2f7d61)
pytorchbot committed Nov 20, 2024
1 parent 8c79542 commit ffb76fd
Showing 4 changed files with 71 additions and 25 deletions.
@@ -38,7 +38,7 @@ SingleGPUExecutor::SingleGPUExecutor(
for (size_t i = 0; i < numProcessThreads_; ++i) {
processExecutor_->add([&]() { process(); });
}
- for (const auto& exec_info : execInfos_) {
+ for ([[maybe_unused]] const auto& exec_info : execInfos_) {
TORCHREC_CHECK(exec_info.interpIdx < manager_->allInstances().size());
}
TORCHREC_CHECK(observer_);
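Note: `[[maybe_unused]]` (C++17) suppresses unused-variable warnings on `exec_info`; the likely motivation (not stated in the commit) is that `TORCHREC_CHECK` can compile to a no-op in some build configurations, which would otherwise leave the loop variable unreferenced.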
33 changes: 33 additions & 0 deletions torchrec/inference/tests/test_inference.py
@@ -13,6 +13,7 @@

import torch
from fbgemm_gpu.split_embedding_configs import SparseType
+ from torchrec import PoolingType
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.distributed.global_settings import set_propogate_device
from torchrec.distributed.test_utils.test_model import (
@@ -298,3 +299,35 @@ def test_sharded_quantized_tbe_count(self) -> None:
spec[1],
expected_num_embeddings[spec[0]],
)

+ def test_quantized_tbe_count_different_pooling(self) -> None:
+     set_propogate_device(True)
+
+     self.tables[0].pooling = PoolingType.MEAN
+     model = TestSparseNN(
+         tables=self.tables,
+         weighted_tables=self.weighted_tables,
+         num_float_features=10,
+         dense_device=torch.device("cpu"),
+         sparse_device=torch.device("cpu"),
+         over_arch_clazz=TestOverArchRegroupModule,
+     )
+
+     model.eval()
+     _, local_batch = ModelInput.generate(
+         batch_size=16,
+         world_size=1,
+         num_float_features=10,
+         tables=self.tables,
+         weighted_tables=self.weighted_tables,
+     )
+
+     model(local_batch[0])
+
+     # Quantize the model and collect quantized weights
+     quantized_model = quantize_inference_model(model)
+     # We should have 2 TBEs for the unweighted EBC, since its 2 tables have different pooling types
+     self.assertTrue(len(quantized_model.sparse.ebc.tbes) == 2)
+     self.assertTrue(len(quantized_model.sparse.weighted_ebc.tbes) == 1)
+     # Restore the original pooling type so other tests are unaffected
+     self.tables[0].pooling = PoolingType.SUM
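Why the counts come out this way: with this commit, quantized tables are grouped into one TBE per pooling type (see the torchrec/quant/embedding_modules.py change below). A minimal sketch of that grouping rule, using a stand-in `PoolingType` enum and dict-based table records rather than the real torchrec configs:

# Illustrative sketch only: bucket tables by pooling type alone, which is the
# rule this test exercises. `PoolingType` and the tables are stand-ins.
from collections import defaultdict
from enum import Enum

class PoolingType(Enum):
    SUM = "SUM"
    MEAN = "MEAN"

tables = [
    {"name": "table_0", "pooling": PoolingType.MEAN},
    {"name": "table_1", "pooling": PoolingType.SUM},
]

key_to_tables = defaultdict(list)
for table in tables:
    key_to_tables[table["pooling"]].append(table)

# One fused TBE module is built per key, so these two tables yield 2 TBEs.
assert len(key_to_tables) == 2

With a shared pooling type the same loop produces a single group, which matches the single TBE asserted for the weighted EBC above.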
51 changes: 32 additions & 19 deletions torchrec/modules/embedding_modules.py
@@ -102,18 +102,25 @@ class EmbeddingBagCollection(EmbeddingBagCollectionInterface):
For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingBagCollection.
- It processes sparse data in the form of `KeyedJaggedTensor` with values of the form
- [F X B X L] where:
+ It is callable on arguments representing sparse data in the form of `KeyedJaggedTensor` with values of the shape
+ `(F, B, L_{f,i})` where:

- * F: features (keys)
- * B: batch size
- * L: length of sparse features (jagged)
+ * `F`: number of features (keys)
+ * `B`: batch size
+ * `L_{f,i}`: length of sparse features (potentially distinct for each feature `f` and batch index `i`, that is, jagged)

- and outputs a `KeyedTensor` with values of the form [B * (F * D)] where:
+ and outputs a `KeyedTensor` with values of shape `(B, D)` where:

- * F: features (keys)
- * D: each feature's (key's) embedding dimension
- * B: batch size
+ * `B`: batch size
+ * `D`: sum of embedding dimensions of all embedding tables, that is, `sum([config.embedding_dim for config in tables])`

+ Assuming the argument is a `KeyedJaggedTensor` `J` with `F` features, batch size `B`, and `L_{f,i}` sparse lengths
+ such that `J[f][i]` is the bag for feature `f` and batch index `i`, the output `KeyedTensor` `KT` is defined as follows:
+ `KT[i] = torch.cat([emb[f](J[f][i]) for f in J.keys()])`, where `emb[f]` is the `EmbeddingBag` corresponding to feature `f`.
+ Note that `J[f][i]` is a variable-length list of integer values (a bag), and `emb[f](J[f][i])` is the pooled embedding
+ produced by reducing the embeddings of each of the values in `J[f][i]`
+ using the `EmbeddingBag` `emb[f]`'s mode (the default is the mean).
Args:
tables (List[EmbeddingBagConfig]): list of embedding tables.
@@ -131,28 +138,34 @@ class EmbeddingBagCollection(EmbeddingBagCollectionInterface):
ebc = EmbeddingBagCollection(tables=[table_0, table_1])
- #        0       1       2   <-- batch
- # "f1"   [0,1]   None    [2]
- # "f2"   [3]     [4]     [5,6,7]
+ #        i = 0   i = 1   i = 2   <-- batch indices
+ # "f1"   [0,1]   None    [2]
+ # "f2"   [3]     [4]     [5,6,7]
#  ^
- # feature
+ # features
features = KeyedJaggedTensor(
keys=["f1", "f2"],
- values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
- offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
+ values=torch.tensor([0, 1, 2,          # feature 'f1'
+                      3, 4, 5, 6, 7]),  # feature 'f2'
+ #   i = 0  i = 1  i = 2  <--- batch indices
+ offsets=torch.tensor([
+     0, 2, 2,       # 'f1' bags are values[0:2], values[2:2], and values[2:3]
+     3, 4, 5, 8]),  # 'f2' bags are values[3:4], values[4:5], and values[5:8]
)
pooled_embeddings = ebc(features)
print(pooled_embeddings.values())
- tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783],
-         [ 0.0000,  0.0000,  0.0000,  0.1598,  0.0695,  1.3265, -0.1011],
-         [-0.4256, -1.1846, -2.1648, -1.0893,  0.3590, -1.9784, -0.7681]],
+ tensor([
+     # f1 pooled embeddings from bags (dim 3)    f2 pooled embeddings from bags (dim 4)
+     [-0.8899, -0.1342, -1.9060,                 -0.0905, -0.2814, -0.9369, -0.7783],  # batch index 0
+     [ 0.0000,  0.0000,  0.0000,                  0.1598,  0.0695,  1.3265, -0.1011],  # batch index 1
+     [-0.4256, -1.1846, -2.1648,                 -1.0893,  0.3590, -1.9784, -0.7681]], # batch index 2
grad_fn=<CatBackward0>)
print(pooled_embeddings.keys())
['f1', 'f2']
print(pooled_embeddings.offset_per_key())
- tensor([0, 3, 7])
+ tensor([0, 3, 7])  # embeddings have dimensions 3 and 4, so embeddings are at [0, 3) and [3, 7).
"""

def __init__(
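The pooled-output semantics documented above can be sanity-checked against plain `torch.nn.EmbeddingBag`. A minimal sketch, not part of this commit; the table sizes and the explicit `mode="sum"` are illustrative assumptions:

import torch

B = 3  # batch size, as in the docstring example
emb_f1 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode="sum")
emb_f2 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")

# Jagged bags per feature, mirroring the docstring example (None -> empty bag).
bags_f1 = [torch.tensor([0, 1]), torch.tensor([], dtype=torch.long), torch.tensor([2])]
bags_f2 = [torch.tensor([3]), torch.tensor([4]), torch.tensor([5, 6, 7])]

def pool(emb, bags):
    # EmbeddingBag takes flattened values plus per-bag offsets; an empty bag
    # pools to a zero row, matching the all-zero f1 row at batch index 1 above.
    values = torch.cat(bags)
    offsets = torch.tensor([0] + [len(b) for b in bags]).cumsum(0)[:-1]
    return emb(values, offsets)

# KT[i] = torch.cat([emb[f](J[f][i]) for f in J.keys()]), vectorized over i.
out = torch.cat([pool(emb_f1, bags_f1), pool(emb_f2, bags_f2)], dim=1)
assert out.shape == (B, 3 + 4)  # (B, D); per-key offsets at [0, 3) and [3, 7)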
10 changes: 5 additions & 5 deletions torchrec/quant/embedding_modules.py
@@ -382,15 +382,14 @@ def __init__(
if table.name in table_names:
raise ValueError(f"Duplicate table name {table.name}")
table_names.add(table.name)
- key = (table.pooling, table.data_type)
- self._key_to_tables[key].append(table)
+ # pyre-ignore
+ self._key_to_tables[table.pooling].append(table)

location = (
EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
)

- for key, emb_configs in self._key_to_tables.items():
-     (pooling, data_type) = key
+ for pooling, emb_configs in self._key_to_tables.items():
embedding_specs = []
weight_lists: Optional[
List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -409,7 +408,7 @@ def __init__(
else table.num_embeddings
),
table.embedding_dim,
- data_type_to_sparse_type(data_type),
+ data_type_to_sparse_type(table.data_type),
location,
)
)
@@ -421,6 +420,7 @@ def __init__(

emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs=embedding_specs,
+ # pyre-ignore
pooling_mode=pooling_type_to_pooling_mode(pooling),
weight_lists=weight_lists,
device=device,
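A before/after sketch of the keying change in this file, with stand-in enums and table records rather than the real torchrec configs: tables that differ only in data type used to land in separate TBE groups and now share one, with each table's data type flowing into its embedding spec via `data_type_to_sparse_type(table.data_type)`:

# Illustrative only; `Pooling`, `DataType`, and the tables are stand-ins.
from collections import defaultdict
from enum import Enum

class Pooling(Enum):
    SUM = "SUM"

class DataType(Enum):
    INT8 = "INT8"
    INT4 = "INT4"

tables = [
    {"name": "t0", "pooling": Pooling.SUM, "data_type": DataType.INT8},
    {"name": "t1", "pooling": Pooling.SUM, "data_type": DataType.INT4},
]

old_groups, new_groups = defaultdict(list), defaultdict(list)
for t in tables:
    old_groups[(t["pooling"], t["data_type"])].append(t)  # old key: 2 groups
    new_groups[t["pooling"]].append(t)                    # new key: 1 group

assert len(old_groups) == 2 and len(new_groups) == 1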
