Skip to content

Commit 8c88220

Browse files
authored
Add Custom OP:Merged EmbeddingBag (#355)
* Add Custom OP:Merged EmbeddingBag * remove commnents for merge input * rename merge_input to linearize_indices_and_offsets * move need_linearize_indices_and_offsets as part of input of forward
1 parent 963df91 commit 8c88220

14 files changed

+1504
-50
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .frozen_batch_norm import FrozenBatchNorm2d
22
from . import _roi_align
3+
from .merged_embeddingbag import MergedEmbeddingBagWithSGD
34
from .linear_fuse_eltwise import IPEXLinearEltwise
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
import torch
2+
from torch import Tensor, nn
3+
from torch.autograd import Function
4+
from typing import List, Optional, NamedTuple
5+
from itertools import accumulate
6+
import enum
7+
8+
@enum.unique
class PoolingMode(enum.IntEnum):
    """Integer codes for the bag reductions supported by the merged op.

    MergedEmbeddingBag maps the EmbeddingBag mode strings ``'sum'`` and
    ``'mean'`` onto these members before handing them to the native ops.
    """

    SUM = 0
    MEAN = 1
11+
12+
class SGDArgs(NamedTuple):
    """Optimizer arguments forwarded to the fused backward + SGD op."""

    # One entry per table; MergedEmbeddingBagWithSGD fills it with a
    # bfloat16 tensor for each weight (an empty tensor for non-bf16 weights).
    bf16_trail: List[Optional[torch.Tensor]]
    # Validated as non-negative by MergedEmbeddingBagWithSGD.init_sgd_args.
    weight_decay: float
    # Learning rate; validated as non-negative by init_sgd_args.
    lr: float
16+
17+
class EmbeddingSpec(NamedTuple):
    """Description of one embedding table to be merged.

    MergedEmbeddingBag unpacks each spec positionally, so the field order
    below is part of the contract.
    """

    # Number of rows in the table (``weight.shape[0]``).
    num_of_features: int
    # Width of each embedding row (``weight.shape[1]``).
    feature_size: int
    # Pooling mode string of this single table: 'sum' or 'mean'.
    pooling_modes: str
    # dtype used to allocate the weight when ``weight`` is None.
    dtype: torch.dtype
    # Existing weight to wrap, or None to allocate an uninitialized one.
    weight: Optional[torch.Tensor]
23+
24+
def merged_embeddingbag_sgd(
    indices,
    offsets,
    indices_with_row_offsets,
    row_offsets,
    pooling_modes,
    sgd_args,
    *weights
):
    """Run the merged EmbeddingBag forward, with or without autograd.

    When gradients are disabled the native forward op is called directly
    with ``indices``, ``offsets``, ``weights`` and ``pooling_modes``.
    Otherwise every argument is routed through MergedEmbeddingBagSGDFunc so
    that its backward can run the fused backward + SGD op.
    """
    if not torch.is_grad_enabled():
        # Inference path: the backward-only arguments are not needed.
        return torch.ops.torch_ipex.merged_embeddingbag_forward(indices, offsets, weights, pooling_modes)

    non_weight_inputs = (
        indices, offsets, indices_with_row_offsets, row_offsets, pooling_modes, sgd_args,
    )
    return MergedEmbeddingBagSGDFunc.apply(*non_weight_inputs, *weights)
38+
39+
class MergedEmbeddingBagSGDFunc(Function):
    """Autograd function pairing the merged forward op with the fused
    backward + SGD op.

    ``backward`` hands the incoming gradients straight to
    ``merged_embeddingbag_backward_sgd`` and reports no gradient for any
    forward input.
    """

    @staticmethod
    def unpack(*args):
        """Return the positional arguments unchanged, as a tuple."""
        return args

    @staticmethod
    def forward(ctx, indices, offsets, indices_with_row_offsets, row_offsets, pooling_modes, sgd_args, *weights):
        """Call the native forward op and stash every input on ``ctx``."""
        result = torch.ops.torch_ipex.merged_embeddingbag_forward(
            indices, offsets, weights, pooling_modes
        )
        # Inputs are kept as plain attributes on ctx for use in backward.
        ctx.sgd_args = sgd_args
        ctx.pooling_modes = pooling_modes
        ctx.row_offsets = row_offsets
        ctx.indices_with_row_offsets = indices_with_row_offsets
        ctx.weights = weights
        ctx.offsets = offsets
        ctx.indices = indices
        return tuple(result)

    @staticmethod
    def backward(ctx, *grad_out):
        """Run the fused backward + SGD op; return ``None`` for every input."""
        optimizer_args = ctx.sgd_args
        torch.ops.torch_ipex.merged_embeddingbag_backward_sgd(
            grad_out,
            ctx.indices,
            ctx.offsets,
            ctx.weights,
            ctx.indices_with_row_offsets,
            ctx.row_offsets,
            ctx.pooling_modes,
            optimizer_args.bf16_trail,
            optimizer_args.weight_decay,
            optimizer_args.lr,
        )
        # forward takes 6 non-weight inputs plus one weight per table.
        return (None,) * (len(ctx.weights) + 6)
77+
78+
class MergedEmbeddingBag(nn.Module):
    r"""
    Merge multiple PyTorch EmbeddingBag modules
    (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py#L221)
    into one torch.nn.Module.

    Native usage for multiple EmbeddingBag is:

        >>> EmbLists = torch.nn.ModuleList([emb1, emb2, emb3, ..., emb_m])
        >>> inputs = [in1, in2, in3, ..., in_m]
        >>> outputs = []
        >>> for i in range(len(EmbLists)):
        >>>     outputs.append(EmbLists[i](inputs[i]))

    Our optimized path will be:

        >>> merged_emb = MergedEmbeddingBagWithSGD(args)
        >>> outputs = merged_emb(input)

    Benefits of the optimized path:

    1). We save PyTorch OP dispatch overhead. If the EmbeddingBag operations
        are not heavy, this dispatch overhead has a big impact.

    We introduce a "linearize_indices_and_offsets" step to merge the
    indices/offsets together. Since EmbeddingBags are usually the first layer
    of a model, "linearize_indices_and_offsets" can be considered data
    preprocessing and can be done offline.

    This Module can not be used alone; use MergedEmbeddingBagWith[Optimizer]
    instead. Currently only MergedEmbeddingBagWithSGD is available and more
    optimizers are planned. For the introduction of
    MergedEmbeddingBagWith[Optimizer], see the comments at
    MergedEmbeddingBagWithSGD.
    """
    embedding_specs: List[EmbeddingSpec]

    def __init__(
        self,
        embedding_specs: List[EmbeddingSpec],
    ):
        """Create one parameter per table and the cumulative row offsets.

        Args:
            embedding_specs: one spec per table, unpacked positionally as
                (num_of_features, feature_size, mode, dtype, weight).

        Raises:
            AssertionError: if a spec's mode is neither 'sum' nor 'mean'.
        """
        super(MergedEmbeddingBag, self).__init__()
        self.n_tables = len(embedding_specs)
        self.pooling_modes = []
        # Never populated here; kept so code reading the attribute still works.
        self.dtypes = []
        rows_per_table = []
        params = []
        for emb in embedding_specs:
            num_of_features, feature_size, mode, dtype, weight = emb
            rows_per_table.append(num_of_features)
            if mode == 'sum':
                self.pooling_modes.append(PoolingMode.SUM)
            elif mode == 'mean':
                self.pooling_modes.append(PoolingMode.MEAN)
            else:
                # Raised explicitly: a bare `assert False` is stripped under
                # `python -O`, which would silently skip the unsupported mode.
                raise AssertionError(r"MergedEmbeddingBag only support EmbeddingBag with model sum or mean")
            if weight is None:
                weight = torch.empty((num_of_features, feature_size), dtype=dtype)
            params.append(nn.Parameter(weight))
        self.weights = torch.nn.ParameterList(params)
        # row_offsets[i] is the first row of table i in the single logical table.
        self.register_buffer(
            "row_offsets",
            torch.tensor([0] + list(accumulate(rows_per_table)), dtype=torch.int64),
        )

    def extra_repr(self) -> str:
        """Describe the table count and each table's rows, width, mode and dtype."""
        header = 'number of tables={}\n'.format(self.n_tables)
        table_lines = [
            "table{}: {}, {}, {}, {}".format(
                i, self.weights[i].shape[0], self.weights[i].shape[1], self.pooling_modes[i], self.weights[i].dtype)
            for i in range(self.n_tables)
        ]
        return header + '\n'.join(table_lines)

    def linearize_indices_and_offsets(
        self,
        indices: List[Tensor],
        offsets: List[Optional[Tensor]],
        include_last_offsets: List[bool]
    ):
        r"""
        To make backward/update more balanced, MergedEmbeddingBag has only one
        logical table and uses unified indices to access the whole logical table.
        The indices from different tables are re-marked to distinguish them.
        For example, with 2 tables of shape [200, 128] and [100, 128], index 50
        of table1 stays 50 and index 50 of table2 becomes 50 + 200 = 250.

        The original indices and offsets are assumed to follow the usage of
        PyTorch EmbeddingBag:
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py#L355-L382

        Args:
            indices: one index tensor per table; 2-D tensors are treated as
                fixed-size bags, otherwise ``offsets[i]`` marks the bag starts.
            offsets: one offsets tensor per table, or None for 2-D indices.
            include_last_offsets: per table, whether ``offsets[i]`` carries a
                trailing end offset that must be dropped.

        Returns:
            A tuple ``(merged_indices, merged_offsets,
            merged_indices_with_row_offsets)`` of int64 tensors.
            ``merged_offsets`` has ``batch_size * n_tables + 1`` entries and
            its last entry is the total number of indices.
        """
        # TODO: support per_sample_weights in forward
        def get_batch_size(indice, offset, include_last_offset):
            if indice.dim() == 2:
                assert offset is None, "offset should be None if indice is 2-D tensor, https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py#L355-L382"
                batch_size = indice.shape[0]
            else:
                batch_size = offset.numel()
                if include_last_offset:
                    batch_size -= 1
            return batch_size

        assert self.n_tables == len(indices), "expected {} but got {} indices".format(self.n_tables, len(indices))
        assert self.n_tables == len(offsets), "expected {} but got {} offsets".format(self.n_tables, len(offsets))
        assert self.n_tables == len(include_last_offsets), "expected {} but got {} include_last_offsets".format(
            self.n_tables, len(include_last_offsets))

        batch_size = get_batch_size(indices[0], offsets[0], include_last_offsets[0])
        assert all(
            batch_size == get_batch_size(idx, offset, include_last)
            for idx, offset, include_last in zip(indices, offsets, include_last_offsets)
        ), r"MergedEmbeddingBag only support input with same batch size"
        n_indices = sum(t.numel() for t in indices)
        n_offsets = batch_size * self.n_tables + 1  # include last offset
        merged_indices = torch.empty(n_indices, dtype=torch.int64)
        merged_indices_with_row_offsets = torch.empty(n_indices, dtype=torch.int64)  # used for sort together
        merged_offsets = torch.empty(n_offsets, dtype=torch.int64)
        idx_start = 0
        offset_start = 0
        for i in range(self.n_tables):
            flat_indice = indices[i].view(-1)
            n_indice = flat_indice.numel()
            merged_indices[idx_start: idx_start + n_indice].copy_(flat_indice)
            merged_indices_with_row_offsets[idx_start: idx_start + n_indice].copy_(flat_indice + self.row_offsets[i])
            if indices[i].dim() == 2:
                bag_size = indices[i].shape[1]
                # Same values as arange(0, numel, bag_size), but also valid for
                # empty bags (bag_size == 0), where a zero step would raise.
                offset = torch.arange(batch_size, dtype=torch.int64) * bag_size
            else:
                offset = offsets[i][:-1] if include_last_offsets[i] else offsets[i]
            assert offset.numel() == batch_size
            # Shift per-table offsets so they point into the merged index buffer.
            merged_offsets[offset_start: offset_start + batch_size].copy_(offset + idx_start)
            idx_start += n_indice
            offset_start += batch_size
        assert idx_start == n_indices
        assert offset_start == n_offsets - 1
        merged_offsets[-1] = n_indices
        return (merged_indices, merged_offsets, merged_indices_with_row_offsets)

    def forward(self, input, need_linearize_indices_and_offsets=torch.BoolTensor([True])):
        """Always fail: the base module has no optimizer-fused forward.

        Raises:
            AssertionError: unconditionally.
        """
        # Raised explicitly so the call cannot silently return None under `python -O`.
        raise AssertionError("Please use MergedEmbeddingBagWith[Optimizer]. We only support SGD now, so please create module MergedEmbeddingBagWithSGD instead")
206+
207+
208+
class MergedEmbeddingBagWithSGD(MergedEmbeddingBag):
    r"""
    To support training for MergedEmbeddingBag with good performance, the
    optimizer step is fused with the backward function.

    Native usage for multiple EmbeddingBag is:

        >>> EmbLists = torch.nn.ModuleList([emb1, emb2, emb3, ..., emb_m])
        >>> sgd = torch.optim.SGD(EmbLists.parameters(), lr=lr, weight_decay=weight_decay)
        >>> inputs = [in1, in2, in3, ..., in_m]
        >>> outputs = []
        >>> for i in range(len(EmbLists)):
        >>>     outputs.append(EmbLists[i](inputs[i]))
        >>> sgd.zero_grad()
        >>> for i in range(len(outputs)):
        >>>     outputs[i].backward(grads[i])
        >>> sgd.step()

    Our optimized path will be:

        >>> # create MergedEmbeddingBagWithSGD module with optimizer args (lr and weight decay)
        >>> merged_emb = MergedEmbeddingBagWithSGD(args)
        >>> merged_input = merged_emb.linearize_indices_and_offsets(indices, offsets, include_last_offsets)
        >>> outputs = merged_emb(merged_input, torch.BoolTensor([False]))
        >>> outputs.backward(grads)

    Further benefits in training:

    1). We further save PyTorch OP dispatch overhead in the backward and
        weight update process.
    2). Thread load is better balanced during backward/weight update. In
        real-world scenarios, EmbeddingBags are often used to represent
        categorical features, which often follow a power-law distribution.
        For example, if one embedding table represents the age range of the
        users of a video game website, most users may be 10-19 or 20-29, so
        the rows representing 10-19 or 20-29 are updated frequently. Since
        updating these rows writes to the same memory address, it has to be
        done by 1 thread (otherwise there are write conflicts, or overhead to
        resolve them). Merging multiple tables together gives a friendlier
        distribution for scheduling backward/update tasks.
    3). The update is fused with backward. The weight is updated immediately
        after the grad is produced, so the memory access pattern is friendlier
        and data is more likely to be served from cache.
    """
    embedding_specs: List[EmbeddingSpec]

    def __init__(
        self,
        embedding_specs: List[EmbeddingSpec],
        lr: float = 0.01,
        weight_decay: float = 0
    ):
        """Build the merged tables and the per-table SGD state.

        Args:
            embedding_specs: one spec per table (see MergedEmbeddingBag).
            lr: learning rate, must be non-negative.
            weight_decay: weight decay, must be non-negative.

        Raises:
            ValueError: if ``lr`` or ``weight_decay`` is negative.
        """
        super(MergedEmbeddingBagWithSGD, self).__init__(embedding_specs)
        self.sgd_args = self.init_sgd_args(lr, weight_decay)
        for i in range(self.n_tables):
            weight = self.weights[i]
            if weight.dtype == torch.bfloat16:
                trail = torch.zeros_like(weight, dtype=torch.bfloat16)
            else:
                # Non-bf16 tables get an empty placeholder so the list stays
                # aligned with the tables.
                trail = torch.empty(0, dtype=torch.bfloat16)
            self.sgd_args.bf16_trail.append(trail)

    def init_sgd_args(self, lr, weight_decay, bf16_trail=None):
        """Validate the hyper-parameters and pack them into an SGDArgs.

        Args:
            lr: learning rate, must be non-negative.
            weight_decay: weight decay, must be non-negative.
            bf16_trail: list to store in the result; a new empty list is
                created when omitted.

        Raises:
            ValueError: if ``lr`` or ``weight_decay`` is negative.
        """
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        # A fresh list per call: __init__ appends to it, so a shared default
        # list would leak trail tensors from one module instance into the next.
        if bf16_trail is None:
            bf16_trail = []
        return SGDArgs(
            weight_decay=weight_decay,
            lr=lr,
            bf16_trail=bf16_trail
        )

    def to_bfloat16_train(self):
        r"""
        Cast the weights to bf16 and rebuild their trail parts for training.

        float and double weights are split with
        ``torch.ops.torch_ipex.split_float_bfloat16`` (double is first cast to
        float); bf16 weights are kept and get an all-zero trail.

        Raises:
            AssertionError: if a weight is not bfloat16, float or double.
        """
        trails = []
        for i in range(len(self.weights)):
            weight = self.weights[i]
            if weight.dtype == torch.float:
                bf16_w, trail = torch.ops.torch_ipex.split_float_bfloat16(weight)
            elif weight.dtype == torch.bfloat16:
                bf16_w = weight
                trail = torch.zeros_like(bf16_w, dtype=torch.bfloat16)
            elif weight.dtype == torch.double:
                bf16_w, trail = torch.ops.torch_ipex.split_float_bfloat16(weight.float())
            else:
                # Raised explicitly: under `python -O` a bare `assert False`
                # is stripped and stale/undefined bf16_w and trail would be used.
                raise AssertionError(r"MergedEmbeddingBag only support dtypes with bfloat, float and double")
            trails.append(trail)
            self.weights[i] = torch.nn.Parameter(bf16_w)
        self.sgd_args = self.sgd_args._replace(bf16_trail=trails)

    def forward(self, input, need_linearize_indices_and_offsets=torch.BoolTensor([True])):
        r"""
        Args:
            input (Tuple[Tensor]): a tuple of (indices, offsets, include_last_offsets(if not merged)/indices_with_row_offsets(if merged))
            need_linearize_indices_and_offsets: indicate whether input need to be linearized
        Returns:
            List[Tensor] output shape of `(batch_size, feature_size)` which length = num of tables.
        """
        if need_linearize_indices_and_offsets.item():
            indices, offsets, include_last_offsets = input
            indices, offsets, indices_with_row_offsets = self.linearize_indices_and_offsets(
                indices, offsets, include_last_offsets)
        else:
            # Input was linearized ahead of time.
            indices, offsets, indices_with_row_offsets = input
        return merged_embeddingbag_sgd(
            indices, offsets, indices_with_row_offsets, self.row_offsets,
            self.pooling_modes, self.sgd_args, *self.weights
        )

    @classmethod
    def from_embeddingbag_list(
        cls,
        tables: List[torch.nn.EmbeddingBag],
        lr: float = 0.01,
        weight_decay: float = 0
    ):
        """Build a merged module from existing EmbeddingBag modules.

        Each table's shape, mode, dtype and detached weight are copied into
        an EmbeddingSpec, in the order given.
        """
        embedding_specs = [
            EmbeddingSpec(
                num_of_features=emb.weight.shape[0],
                feature_size=emb.weight.shape[1],
                pooling_modes=emb.mode,
                dtype=emb.weight.dtype,
                weight=emb.weight.detach()
            )
            for emb in tables
        ]
        return cls(embedding_specs, lr, weight_decay)

tests/cpu/bench/custom_op_bench/README.md

+12
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,15 @@ python -m intel_extension_for_pytorch.cpu.launch --socket_id 0 optimizer.py --op
3232
python -m intel_extension_for_pytorch.cpu.launch --socket_id 0 optimizer.py --optimizer lamb # for lamb
3333
python -m intel_extension_for_pytorch.cpu.launch --socket_id 0 optimizer.py --optimizer adagrad # for adagrad
3434
```
35+
36+
## Evaluate IPEX [MergedEmbeddingBag](../../../../intel_extension_for_pytorch/nn/module/merged_embeddingbag.py)
37+
```
38+
export CORES=`lscpu | grep Core | awk '{print $4}'`
39+
export BATCHSIZE=$((128*CORES))
40+
# Data distribution will not impact inference performance
41+
python -m intel_extension_for_pytorch.cpu.launch --socket_id 0 merged_embeddingbag.py --inference --data-distribution=balance --batch-size=${BATCHSIZE}
42+
43+
# For training, the data distribution has a big impact on weight-update performance. With the "unbalance" argument, we generate data in which half of the indices update the same row (which is similar to real-world datasets such as the DLRM MLPerf dataset)
44+
python -m intel_extension_for_pytorch.cpu.launch --socket_id 0 merged_embeddingbag.py --data-distribution=balance --batch-size=${BATCHSIZE}
45+
python -m intel_extension_for_pytorch.cpu.launch --socket_id 0 merged_embeddingbag.py --data-distribution=unbalance --batch-size=${BATCHSIZE}
46+
```

0 commit comments

Comments
 (0)