@@ -102,18 +102,25 @@ class EmbeddingBagCollection(EmbeddingBagCollectionInterface):
For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingBagCollection.
- It processes sparse data in the form of `KeyedJaggedTensor` with values of the form
- [F X B X L] where:
+ It is callable on arguments representing sparse data in the form of a `KeyedJaggedTensor` with values of shape
+ `(F, B, L_{f,i})` where:
- * F: features (keys)
- * B : batch size
- * L : length of sparse features (jagged)
+ * `F`: number of features (keys)
+ * `B`: batch size
+ * `L_{f,i}`: length of sparse features (potentially distinct for each feature `f` and batch index `i`, that is, jagged)
- and outputs a `KeyedTensor` with values of the form [B * (F * D)] where:
+ and outputs a `KeyedTensor` with values of shape `(B, D)` where:
- * F: features (keys)
- * D: each feature's (key's) embedding dimension
- * B: batch size
+ * `B`: batch size
+ * `D`: sum of embedding dimensions of all embedding tables, that is, `sum([config.embedding_dim for config in tables])`
+
+ Assuming the argument is a `KeyedJaggedTensor` `J` with `F` features, batch size `B`, and sparse lengths `L_{f,i}`
+ such that `J[f][i]` is the bag for feature `f` and batch index `i`, the output `KeyedTensor` `KT` is defined as follows:
+ `KT[i]` = `torch.cat([emb[f](J[f][i]) for f in J.keys()])`, where `emb[f]` is the `EmbeddingBag` corresponding to feature `f`.
+
+ Note that `J[f][i]` is a variable-length list of integer values (a bag), and `emb[f](J[f][i])` is the pooled embedding
+ produced by reducing the embeddings of each of the values in `J[f][i]`
+ using the `EmbeddingBag` `emb[f]`'s mode (the default is the mean).
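+
+ For intuition, this pooling can be sketched with plain `torch.nn.EmbeddingBag` modules
+ (a minimal sketch only; the table sizes and `mode="mean"` below are illustrative
+ assumptions, not part of this module's API):
+
+     emb = {
+         "f1": torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode="mean"),
+         "f2": torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="mean"),
+     }
+     # 1D input + offsets form; the empty bag at ("f1", i=1) pools to zeros
+     f1 = emb["f1"](torch.tensor([0, 1, 2]), torch.tensor([0, 2, 2]))        # shape (3, 3)
+     f2 = emb["f2"](torch.tensor([3, 4, 5, 6, 7]), torch.tensor([0, 1, 2]))  # shape (3, 4)
+     kt_values = torch.cat([f1, f2], dim=1)                                  # shape (B, D) = (3, 7)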
Args:
tables (List[EmbeddingBagConfig]): list of embedding tables.
@@ -131,28 +138,34 @@ class EmbeddingBagCollection(EmbeddingBagCollectionInterface):
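+ # table_0 and table_1 are EmbeddingBagConfigs defined in lines elided from this
+ # hunk; a plausible sketch consistent with the printed output dims (3 and 4),
+ # with illustrative names and num_embeddings:
+ #   table_0 = EmbeddingBagConfig(name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"])
+ #   table_1 = EmbeddingBagConfig(name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"])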
ebc = EmbeddingBagCollection(tables=[table_0, table_1])
- #          0        1        2    <-- batch
- # "f1"   [0,1]    None     [2]
- # "f2"   [3]      [4]      [5,6,7]
+ #          i = 0    i = 1    i = 2   <-- batch indices
+ # "f1"     [0,1]    None     [2]
+ # "f2"     [3]      [4]      [5,6,7]
# ^
- # feature
+ # features
features = KeyedJaggedTensor(
keys=["f1", "f2"],
- values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
- offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
+ values=torch.tensor([0, 1, 2,         # feature 'f1'
+                      3, 4, 5, 6, 7]), # feature 'f2'
+ #        i = 0    i = 1    i = 2   <--- batch indices
+ offsets=torch.tensor([
+     0, 2, 2,       # 'f1' bags are values[0:2], values[2:2], and values[2:3]
+     3, 4, 5, 8]),  # 'f2' bags are values[3:4], values[4:5], and values[5:8]
)
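+
+ # An equivalent construction from per-bag lengths (a sketch assuming torchrec's
+ # KeyedJaggedTensor.from_lengths_sync helper; `features_alt` is an illustrative name):
+ features_alt = KeyedJaggedTensor.from_lengths_sync(
+     keys=["f1", "f2"],
+     values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+     lengths=torch.tensor([2, 0, 1, 1, 1, 3]),  # the diffs of the offsets above
+ )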
pooled_embeddings = ebc(features)
print(pooled_embeddings.values())
- tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783],
-         [ 0.0000,  0.0000,  0.0000,  0.1598,  0.0695,  1.3265, -0.1011],
-         [-0.4256, -1.1846, -2.1648, -1.0893,  0.3590, -1.9784, -0.7681]],
+ tensor([
+     # f1 pooled embeddings from bags (dim 3)   f2 pooled embeddings from bags (dim 4)
+     [-0.8899, -0.1342, -1.9060,   -0.0905, -0.2814, -0.9369, -0.7783],  # batch index 0
+     [ 0.0000,  0.0000,  0.0000,    0.1598,  0.0695,  1.3265, -0.1011],  # batch index 1
+     [-0.4256, -1.1846, -2.1648,   -1.0893,  0.3590, -1.9784, -0.7681]], # batch index 2
grad_fn=<CatBackward0>)
print(pooled_embeddings.keys())
['f1', 'f2']
print(pooled_embeddings.offset_per_key())
- tensor([0, 3, 7])
+ tensor([0, 3, 7])  # embedding dims are 3 and 4, so 'f1' occupies columns [0, 3) and 'f2' occupies columns [3, 7).
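+
+ # Per-key access (a sketch assuming KeyedTensor's to_dict helper):
+ # pooled_embeddings.to_dict()["f1"] is the (3, 3) slice values()[:, 0:3], and
+ # pooled_embeddings.to_dict()["f2"] is the (3, 4) slice values()[:, 3:7].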
"""
def __init__ (