
Commit ffb76fd

Author: pytorchbot
Commit message: 2024-11-20 nightly release (c2f7d61)
1 parent 8c79542 commit ffb76fd

4 files changed: +71, -25 lines changed


torchrec/inference/inference_legacy/src/SingleGPUExecutor.cpp

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ SingleGPUExecutor::SingleGPUExecutor(
   for (size_t i = 0; i < numProcessThreads_; ++i) {
     processExecutor_->add([&]() { process(); });
   }
-  for (const auto& exec_info : execInfos_) {
+  for ([[maybe_unused]] const auto& exec_info : execInfos_) {
     TORCHREC_CHECK(exec_info.interpIdx < manager_->allInstances().size());
   }
   TORCHREC_CHECK(observer_);

torchrec/inference/tests/test_inference.py

Lines changed: 33 additions & 0 deletions
@@ -13,6 +13,7 @@
 
 import torch
 from fbgemm_gpu.split_embedding_configs import SparseType
+from torchrec import PoolingType
 from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
 from torchrec.distributed.global_settings import set_propogate_device
 from torchrec.distributed.test_utils.test_model import (
@@ -298,3 +299,35 @@ def test_sharded_quantized_tbe_count(self) -> None:
                 spec[1],
                 expected_num_embeddings[spec[0]],
             )
+
+    def test_quantized_tbe_count_different_pooling(self) -> None:
+        set_propogate_device(True)
+
+        self.tables[0].pooling = PoolingType.MEAN
+        model = TestSparseNN(
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            num_float_features=10,
+            dense_device=torch.device("cpu"),
+            sparse_device=torch.device("cpu"),
+            over_arch_clazz=TestOverArchRegroupModule,
+        )
+
+        model.eval()
+        _, local_batch = ModelInput.generate(
+            batch_size=16,
+            world_size=1,
+            num_float_features=10,
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+        )
+
+        model(local_batch[0])
+
+        # Quantize the model and collect quantized weights
+        quantized_model = quantize_inference_model(model)
+        # Expect 2 TBEs for the unweighted EBC, since its 2 tables now have different pooling types
+        self.assertTrue(len(quantized_model.sparse.ebc.tbes) == 2)
+        self.assertTrue(len(quantized_model.sparse.weighted_ebc.tbes) == 1)
+        # Restore the original pooling type so other tests are unaffected
+        self.tables[0].pooling = PoolingType.SUM
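For context only (not part of the commit), here is a minimal sketch of the configuration this test exercises: two unweighted tables that are identical except for their pooling type. The table names, sizes, and feature names below are illustrative assumptions, not the values used by the test fixtures.

from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, PoolingType

# Two tables that differ only in pooling type.
tables = [
    EmbeddingBagConfig(
        name="table_0",
        embedding_dim=16,
        num_embeddings=100,
        feature_names=["f0"],
        pooling=PoolingType.MEAN,
    ),
    EmbeddingBagConfig(
        name="table_1",
        embedding_dim=16,
        num_embeddings=100,
        feature_names=["f1"],
        pooling=PoolingType.SUM,
    ),
]
ebc = EmbeddingBagCollection(tables=tables)
# Tables with different pooling types cannot share a single pooling mode in one TBE kernel,
# so after quantization this collection is expected to map to two TBE modules.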

torchrec/modules/embedding_modules.py

Lines changed: 32 additions & 19 deletions
@@ -102,18 +102,25 @@ class EmbeddingBagCollection(EmbeddingBagCollectionInterface):
     For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingBagCollection.
 
-    It processes sparse data in the form of `KeyedJaggedTensor` with values of the form
-    [F X B X L] where:
+    It is callable on arguments representing sparse data in the form of `KeyedJaggedTensor` with values of shape
+    `(F, B, L_{f,i})` where:
 
-    * F: features (keys)
-    * B: batch size
-    * L: length of sparse features (jagged)
+    * `F`: number of features (keys)
+    * `B`: batch size
+    * `L_{f,i}`: length of sparse features (potentially distinct for each feature `f` and batch index `i`, that is, jagged)
 
-    and outputs a `KeyedTensor` with values of the form [B * (F * D)] where:
+    and outputs a `KeyedTensor` with values of shape `(B, D)` where:
 
-    * F: features (keys)
-    * D: each feature's (key's) embedding dimension
-    * B: batch size
+    * `B`: batch size
+    * `D`: sum of the embedding dimensions of all embedding tables, that is, `sum([config.embedding_dim for config in tables])`
+
+    Assuming the argument is a `KeyedJaggedTensor` `J` with `F` features, batch size `B`, and `L_{f,i}` sparse lengths
+    such that `J[f][i]` is the bag for feature `f` and batch index `i`, the output `KeyedTensor` `KT` is defined as follows:
+    `KT[i] = torch.cat([emb[f](J[f][i]) for f in J.keys()])`, where `emb[f]` is the `EmbeddingBag` corresponding to feature `f`.
+
+    Note that `J[f][i]` is a variable-length list of integer values (a bag), and `emb[f](J[f][i])` is the pooled embedding
+    produced by reducing the embeddings of each of the values in `J[f][i]`
+    using the `EmbeddingBag` `emb[f]`'s mode (the default is the mean).
 
     Args:
         tables (List[EmbeddingBagConfig]): list of embedding tables.
@@ -131,28 +138,34 @@ class EmbeddingBagCollection(EmbeddingBagCollectionInterface):
 
         ebc = EmbeddingBagCollection(tables=[table_0, table_1])
 
-        #               0        1        2    <-- batch
-        # "f1"        [0,1]     None     [2]
-        # "f2"        [3]       [4]      [5,6,7]
+        #             i = 0     i = 1    i = 2   <-- batch indices
+        # "f1"        [0,1]     None     [2]
+        # "f2"        [3]       [4]      [5,6,7]
         #  ^
-        # feature
+        # features
 
         features = KeyedJaggedTensor(
             keys=["f1", "f2"],
-            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
-            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
+            values=torch.tensor([0, 1, 2,  # feature 'f1'
+                                 3, 4, 5, 6, 7]),  # feature 'f2'
+            #   i = 0    i = 1    i = 2   <--- batch indices
+            offsets=torch.tensor([
+                0, 2, 2,  # 'f1' bags are values[0:2], values[2:2], and values[2:3]
+                3, 4, 5, 8]),  # 'f2' bags are values[3:4], values[4:5], and values[5:8]
         )
 
         pooled_embeddings = ebc(features)
         print(pooled_embeddings.values())
-        tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783],
-                [ 0.0000,  0.0000,  0.0000,  0.1598,  0.0695,  1.3265, -0.1011],
-                [-0.4256, -1.1846, -2.1648, -1.0893,  0.3590, -1.9784, -0.7681]],
+        tensor([
+            # f1 pooled embeddings from bags (dim 3)      f2 pooled embeddings from bags (dim 4)
+            [-0.8899, -0.1342, -1.9060,                   -0.0905, -0.2814, -0.9369, -0.7783],  # batch index 0
+            [ 0.0000,  0.0000,  0.0000,                    0.1598,  0.0695,  1.3265, -0.1011],  # batch index 1
+            [-0.4256, -1.1846, -2.1648,                   -1.0893,  0.3590, -1.9784, -0.7681]],  # batch index 2
                grad_fn=<CatBackward0>)
         print(pooled_embeddings.keys())
         ['f1', 'f2']
         print(pooled_embeddings.offset_per_key())
-        tensor([0, 3, 7])
+        tensor([0, 3, 7])  # embeddings have dimensions 3 and 4, so they are found at [0, 3) and [3, 7)
     """
 
     def __init__(
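As an aside (not part of the diff), the pooled-output formula in the revised docstring can be reproduced with plain `torch.nn.EmbeddingBag`. The sketch below mirrors the docstring example's bags and embedding dimensions; the table sizes are arbitrary assumptions and the weights are random, so only the shapes match the example.

import torch

# One EmbeddingBag per feature: 'f1' has embedding_dim=3 and 'f2' has embedding_dim=4,
# so the concatenated output dimension is D = 3 + 4 = 7, as in the docstring example.
emb_f1 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode="mean")
emb_f2 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="mean")

# Jagged bags over a batch of size 3; 'f1' has an empty bag at batch index 1.
f1_values, f1_offsets = torch.tensor([0, 1, 2]), torch.tensor([0, 2, 2])
f2_values, f2_offsets = torch.tensor([3, 4, 5, 6, 7]), torch.tensor([0, 1, 2])

# KT[i] = torch.cat([emb[f](J[f][i]) for f in keys]), computed for the whole batch at once.
pooled = torch.cat([emb_f1(f1_values, f1_offsets), emb_f2(f2_values, f2_offsets)], dim=1)
print(pooled.shape)  # torch.Size([3, 7]) -> (B, D); the empty 'f1' bag pools to zeros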

torchrec/quant/embedding_modules.py

Lines changed: 5 additions & 5 deletions
@@ -382,15 +382,14 @@ def __init__(
             if table.name in table_names:
                 raise ValueError(f"Duplicate table name {table.name}")
             table_names.add(table.name)
-            key = (table.pooling, table.data_type)
-            self._key_to_tables[key].append(table)
+            # pyre-ignore
+            self._key_to_tables[table.pooling].append(table)
 
         location = (
             EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
         )
 
-        for key, emb_configs in self._key_to_tables.items():
-            (pooling, data_type) = key
+        for pooling, emb_configs in self._key_to_tables.items():
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -409,7 +408,7 @@ def __init__(
                         else table.num_embeddings
                     ),
                     table.embedding_dim,
-                    data_type_to_sparse_type(data_type),
+                    data_type_to_sparse_type(table.data_type),
                     location,
                 )
             )
@@ -421,6 +420,7 @@ def __init__(
 
             emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
                 embedding_specs=embedding_specs,
+                # pyre-ignore
                 pooling_mode=pooling_type_to_pooling_mode(pooling),
                 weight_lists=weight_lists,
                 device=device,
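In short, tables are now bucketed by pooling type alone, and each spec keeps its own data type via `data_type_to_sparse_type(table.data_type)`, so tables that share a pooling mode can share a TBE even when their data types differ. Below is a minimal, self-contained sketch of that grouping; `TableCfg` is a hypothetical stand-in for `EmbeddingBagConfig`, not torchrec's actual class.

from collections import defaultdict
from dataclasses import dataclass

@dataclass
class TableCfg:  # hypothetical stand-in for EmbeddingBagConfig
    name: str
    pooling: str    # e.g. "SUM" or "MEAN"
    data_type: str  # e.g. "INT8" or "FP16"

tables = [
    TableCfg("t0", pooling="MEAN", data_type="INT8"),
    TableCfg("t1", pooling="SUM", data_type="INT8"),
    TableCfg("t2", pooling="SUM", data_type="FP16"),
]

# Before: key = (pooling, data_type), so t1 and t2 landed in separate groups.
# After: key = pooling only, so t1 and t2 share a group (and hence one TBE),
# while each embedding spec still carries its own sparse type from table.data_type.
key_to_tables = defaultdict(list)
for table in tables:
    key_to_tables[table.pooling].append(table)

for pooling, group in key_to_tables.items():
    specs = [(t.name, t.data_type) for t in group]
    print(pooling, specs)  # prints: MEAN [('t0', 'INT8')] then SUM [('t1', 'INT8'), ('t2', 'FP16')]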
