import torch
from torch import Tensor, nn
from torch.autograd import Function
from typing import List, Optional, NamedTuple
from itertools import accumulate
import enum

class PoolingMode(enum.IntEnum):
    SUM = 0
    MEAN = 1

class SGDArgs(NamedTuple):
    bf16_trail: List[Optional[torch.Tensor]]
    weight_decay: float
    lr: float

class EmbeddingSpec(NamedTuple):
    num_of_features: int
    feature_size: int
    pooling_modes: str
    dtype: torch.dtype
    weight: Optional[torch.Tensor]

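# Example (illustrative, not from the original source): a spec for a 1000-row,
# 128-dimensional table with sum pooling. Passing weight=None lets
# MergedEmbeddingBag allocate an uninitialized weight of that shape itself.
#
#   spec = EmbeddingSpec(
#       num_of_features=1000,
#       feature_size=128,
#       pooling_modes="sum",
#       dtype=torch.float32,
#       weight=None,
#   )
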
def merged_embeddingbag_sgd(
    indices,
    offsets,
    indices_with_row_offsets,
    row_offsets,
    pooling_modes,
    sgd_args,
    *weights
):
    # Training goes through the autograd Function (which fuses the SGD update
    # into backward); inference calls the forward kernel directly.
    if torch.is_grad_enabled():
        return MergedEmbeddingBagSGDFunc.apply(
            indices, offsets, indices_with_row_offsets, row_offsets, pooling_modes, sgd_args, *weights
        )
    return torch.ops.torch_ipex.merged_embeddingbag_forward(indices, offsets, weights, pooling_modes)

class MergedEmbeddingBagSGDFunc(Function):
    @staticmethod
    def unpack(*args):
        return args

    @staticmethod
    def forward(ctx, indices, offsets, indices_with_row_offsets, row_offsets, pooling_modes, sgd_args, *weights):
        output = torch.ops.torch_ipex.merged_embeddingbag_forward(
            indices, offsets, weights, pooling_modes
        )
        ctx.indices = indices
        ctx.offsets = offsets
        ctx.weights = weights
        ctx.indices_with_row_offsets = indices_with_row_offsets
        ctx.row_offsets = row_offsets
        ctx.pooling_modes = pooling_modes
        ctx.sgd_args = sgd_args
        return MergedEmbeddingBagSGDFunc.unpack(*output)

    @staticmethod
    def backward(ctx, *grad_out):
        indices = ctx.indices
        offsets = ctx.offsets
        weights = ctx.weights
        indices_with_row_offsets = ctx.indices_with_row_offsets
        row_offsets = ctx.row_offsets
        pooling_modes = ctx.pooling_modes
        sgd_args = ctx.sgd_args
        bf16_trail = sgd_args.bf16_trail
        weight_decay = sgd_args.weight_decay
        lr = sgd_args.lr
        # The SGD step is fused into this kernel: the weights are updated in
        # place, so no gradients need to be returned to autograd.
        torch.ops.torch_ipex.merged_embeddingbag_backward_sgd(
            grad_out, indices, offsets, weights, indices_with_row_offsets,
            row_offsets, pooling_modes,
            bf16_trail, weight_decay, lr)
        n_tables = len(weights)
        # One None per forward input: 6 non-weight args plus n_tables weight tensors.
        output = [None] * (n_tables + 6)
        return MergedEmbeddingBagSGDFunc.unpack(*output)

class MergedEmbeddingBag(nn.Module):
    r"""
    Merge multiple PyTorch EmbeddingBag (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py#L221)
    objects into one torch.nn.Module.
    The native usage for multiple EmbeddingBag modules is:
    >>> EmbLists = torch.nn.ModuleList([emb1, emb2, emb3, ..., emb_m])
    >>> inputs = [in1, in2, in3, ..., in_m]
    >>> outputs = []
    >>> for i in range(len(EmbLists)):
    >>>     outputs.append(EmbLists[i](inputs[i]))
    Our optimized path is:
    >>> merged_emb = MergedEmbeddingBagWithSGD(args)
    >>> outputs = merged_emb(input)
    The optimized path has this benefit:
    1). We save PyTorch OP dispatch overhead; if the EmbeddingBag operations are not
    heavy, this dispatch overhead has a big impact.
    We introduce a "linearize_indices_and_offsets" step to merge indices/offsets together. Since EmbeddingBags
    are usually the first layer of a model, "linearize_indices_and_offsets" can be treated as data preprocessing and
    can be done offline.
    This module cannot be used alone; we suggest using MergedEmbeddingBagWith[Optimizer] instead.
    Currently only MergedEmbeddingBagWithSGD is available, and we plan to add more optimizer support
    in the future.
    For an introduction to MergedEmbeddingBagWith[Optimizer], please see the comments at
    MergedEmbeddingBagWithSGD.
    """
    embedding_specs: List[EmbeddingSpec]

    def __init__(
        self,
        embedding_specs: List[EmbeddingSpec],
    ):
        super(MergedEmbeddingBag, self).__init__()
        self.n_tables = len(embedding_specs)
        row_offsets = []
        self.pooling_modes = []
        self.weights = torch.nn.ParameterList([nn.Parameter(torch.Tensor()) for i in range(len(embedding_specs))])
        for i, emb in enumerate(embedding_specs):
            num_of_features, feature_size, mode, dtype, weight = emb
            row_offsets.append(num_of_features)
            if mode == 'sum':
                self.pooling_modes.append(PoolingMode.SUM)
            elif mode == 'mean':
                self.pooling_modes.append(PoolingMode.MEAN)
            else:
                assert False, r"MergedEmbeddingBag only supports EmbeddingBag with mode 'sum' or 'mean'"
            if weight is None:
                weight = torch.empty((num_of_features, feature_size), dtype=dtype)
            self.weights[i] = nn.Parameter(weight)
        self.register_buffer(
            "row_offsets",
            torch.tensor([0] + list(accumulate(row_offsets)), dtype=torch.int64),
        )
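        # Illustrative example (not from the original source): two tables with
        # 200 and 100 rows give row_offsets == tensor([0, 200, 300]); row i of
        # table t lives at logical row i + row_offsets[t] in the merged table.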

    def extra_repr(self) -> str:
        s = 'number of tables={}\n'.format(self.n_tables)
        for i in range(self.n_tables):
            s += "table{}: {}, {}, {}, {}".format(
                i, self.weights[i].shape[0], self.weights[i].shape[1], self.pooling_modes[i], self.weights[i].dtype)
            if i != self.n_tables - 1:
                s += '\n'
        return s

    def linearize_indices_and_offsets(
        self,
        indices: List[Tensor],
        offsets: List[Optional[Tensor]],
        include_last_offsets: List[bool]
    ):
        r"""
        To make backward/update more balanced, we keep only 1 logical table in MergedEmbeddingBag and
        use unified indices to access the whole logical table.
        We need to re-map the indices from different tables to distinguish them.
        For example, given 2 tables with shapes [200, 128] and [100, 128],
        index 50 for table1 stays 50, while index 50 for table2 becomes 50 + 200 = 250.
        We assume the original indices and offsets follow the semantics of PyTorch EmbeddingBag:
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py#L355-L382
        """
        # TODO: support per_sample_weights in forward
        def get_batch_size(indice, offset, include_last_offset):
            if indice.dim() == 2:
                assert offset is None, "offset should be None if indices is a 2-D tensor, https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py#L355-L382"
                batch_size = indice.shape[0]
            else:
                batch_size = offset.numel()
                if include_last_offset:
                    batch_size -= 1
            return batch_size

        assert self.n_tables == len(indices), "expected {} but got {} indices".format(self.n_tables, len(indices))
        assert self.n_tables == len(offsets), "expected {} but got {} offsets".format(self.n_tables, len(offsets))
        assert self.n_tables == len(include_last_offsets), "expected {} but got {} include_last_offsets".format(
            self.n_tables, len(include_last_offsets))

        batch_size = get_batch_size(indices[0], offsets[0], include_last_offsets[0])
        assert all(
            batch_size == get_batch_size(idx, offset, include_last) for idx, offset, include_last in zip(indices, offsets, include_last_offsets)
        ), r"MergedEmbeddingBag only supports inputs with the same batch size"
        n_indices = sum([t.numel() for t in indices])
        n_offsets = batch_size * self.n_tables + 1  # include last offset
        merged_indices = torch.empty(n_indices, dtype=torch.int64)
        merged_indices_with_row_offsets = torch.empty(n_indices, dtype=torch.int64)  # used to sort together
        merged_offsets = torch.empty(n_offsets, dtype=torch.int64)
        idx_start = 0
        offset_start = 0
        for i in range(self.n_tables):
            n_indice = indices[i].numel()
            merged_indices[idx_start: idx_start + n_indice].copy_(indices[i].view(-1))
            merged_indices_with_row_offsets[idx_start: idx_start + n_indice].copy_(indices[i].view(-1) + self.row_offsets[i])
            if indices[i].dim() == 2:
                bag_size = indices[i].shape[1]
                offset = torch.arange(0, indices[i].numel(), bag_size)
            else:
                offset = offsets[i][:-1] if include_last_offsets[i] else offsets[i]
            assert offset.numel() == batch_size
            merged_offsets[offset_start: offset_start + batch_size].copy_(offset + idx_start)
            idx_start += n_indice
            offset_start += batch_size
        assert idx_start == n_indices
        assert offset_start == n_offsets - 1
        merged_offsets[-1] = n_indices
        return (merged_indices, merged_offsets, merged_indices_with_row_offsets)
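    # Worked example (illustrative, matching the docstring above): two tables
    # with 200 and 100 rows, batch size 2:
    #   indices = [torch.tensor([3, 50]), torch.tensor([50, 7, 99])]
    #   offsets = [torch.tensor([0, 1]), torch.tensor([0, 2])]
    #   include_last_offsets = [False, False]
    # yields:
    #   merged_indices                  = [3, 50, 50, 7, 99]
    #   merged_indices_with_row_offsets = [3, 50, 250, 207, 299]  # table2 rows shifted by 200
    #   merged_offsets                  = [0, 1, 2, 4, 5]         # last entry == n_indices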

    def forward(self, input, need_linearize_indices_and_offsets=torch.BoolTensor([True])):
        assert False, "Please use MergedEmbeddingBagWith[Optimizer]. Only SGD is supported now, so please create a MergedEmbeddingBagWithSGD module instead"


class MergedEmbeddingBagWithSGD(MergedEmbeddingBag):
    r"""
    To support training of MergedEmbeddingBag with good performance, we fuse the optimizer step
    with the backward function.
    The native usage for multiple EmbeddingBag modules is:
    >>> EmbLists = torch.nn.ModuleList([emb1, emb2, emb3, ..., emb_m])
    >>> sgd = torch.optim.SGD(EmbLists.parameters(), lr=lr, weight_decay=weight_decay)
    >>> inputs = [in1, in2, in3, ..., in_m]
    >>> outputs = []
    >>> for i in range(len(EmbLists)):
    >>>     outputs.append(EmbLists[i](inputs[i]))
    >>> sgd.zero_grad()
    >>> for i in range(len(outputs)):
    >>>     outputs[i].backward(grads[i])
    >>> sgd.step()
    Our optimized path is:
    >>> # create a MergedEmbeddingBagWithSGD module with optimizer args (lr and weight decay)
    >>> merged_emb = MergedEmbeddingBagWithSGD(args)
    >>> merged_input = merged_emb.linearize_indices_and_offsets(inputs)
    >>> outputs = merged_emb(merged_input)
    >>> outputs.backward(grads)
    We get further benefits in training:
    1). We further save PyTorch OP dispatch overhead in the backward and weight update process.
    2). We make thread loading more balanced during backward/weight update. In real-world
    scenarios, EmbeddingBag is often used to represent categorical features, and
    categorical features often follow a power-law distribution. For example, if we use one
    embedding table to represent the age range of the users of a video game website, we might
    find most users are aged 10-19 or 20-29, so the rows representing 10-19 and 20-29 need
    to be updated frequently. Since updating these rows writes to the same memory addresses,
    the writes must be done by 1 thread (or we will have write conflicts, or overhead to
    resolve them). By merging multiple tables together, we get a friendlier distribution
    over which to partition backward/update tasks.
    3). We fuse the update with the backward pass. We can update a weight immediately after
    we get its grad from backward, so the memory access pattern is friendlier and we have a
    better chance of finding the data in cache.
    """
    embedding_specs: List[EmbeddingSpec]

    def __init__(
        self,
        embedding_specs: List[EmbeddingSpec],
        lr: float = 0.01,
        weight_decay: float = 0
    ):
        super(MergedEmbeddingBagWithSGD, self).__init__(embedding_specs)
        self.sgd_args = self.init_sgd_args(lr, weight_decay)
        for i in range(self.n_tables):
            weight = self.weights[i]
            if weight.dtype == torch.bfloat16:
                self.sgd_args.bf16_trail.append(torch.zeros_like(weight, dtype=torch.bfloat16))
            else:
                self.sgd_args.bf16_trail.append(torch.empty(0, dtype=torch.bfloat16))

    def init_sgd_args(self, lr, weight_decay, bf16_trail=None):
        # Use None as the default to avoid sharing one mutable default list
        # across calls (and thus across module instances).
        if bf16_trail is None:
            bf16_trail = []
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        return SGDArgs(
            weight_decay=weight_decay,
            lr=lr,
            bf16_trail=bf16_trail
        )

    def to_bfloat16_train(self):
        r"""
        Cast the weights to bf16, keeping a trail tensor (the truncated lower half
        of the fp32 representation) per table for accurate updates during training.
        """
        trails = []
        for i in range(len(self.weights)):
            if self.weights[i].dtype == torch.float:
                bf16_w, trail = torch.ops.torch_ipex.split_float_bfloat16(self.weights[i])
            elif self.weights[i].dtype == torch.bfloat16:
                bf16_w = self.weights[i]
                trail = torch.zeros_like(bf16_w, dtype=torch.bfloat16)
            elif self.weights[i].dtype == torch.double:
                bf16_w, trail = torch.ops.torch_ipex.split_float_bfloat16(self.weights[i].float())
            else:
                assert False, r"MergedEmbeddingBag only supports weights with bfloat16, float, or double dtype"
            trails.append(trail)
            self.weights[i] = torch.nn.Parameter(bf16_w)
        self.sgd_args = self.sgd_args._replace(bf16_trail=trails)
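    # Usage sketch (illustrative): enable bf16 split-SGD training after construction.
    #   merged_emb = MergedEmbeddingBagWithSGD.from_embeddingbag_list(tables)
    #   merged_emb.to_bfloat16_train()  # weights become bf16; trails keep the low bits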

    def forward(self, input, need_linearize_indices_and_offsets=torch.BoolTensor([True])):
        r"""
        Args:
            input (Tuple[Tensor]): a tuple of (indices, offsets, include_last_offsets) if the
                input still needs to be linearized, or (indices, offsets, indices_with_row_offsets)
                if it is already merged.
            need_linearize_indices_and_offsets: indicates whether the input needs to be linearized.
        Returns:
            List[Tensor]: one output per table, each of shape `(batch_size, feature_size)`.
        """
        if need_linearize_indices_and_offsets.item():
            indices, offsets, include_last_offsets = input
            indices, offsets, indices_with_row_offsets = self.linearize_indices_and_offsets(indices, offsets, include_last_offsets)
        else:
            indices, offsets, indices_with_row_offsets = input
        return merged_embeddingbag_sgd(
            indices, offsets, indices_with_row_offsets, self.row_offsets,
            self.pooling_modes, self.sgd_args, *self.weights
        )

    @classmethod
    def from_embeddingbag_list(
        cls,
        tables: List[torch.nn.EmbeddingBag],
        lr: float = 0.01,
        weight_decay: float = 0
    ):
        embedding_specs = []
        for emb in tables:
            emb_shape = emb.weight.shape
            embedding_specs.append(
                EmbeddingSpec(
                    num_of_features=emb_shape[0],
                    feature_size=emb_shape[1],
                    pooling_modes=emb.mode,
                    dtype=emb.weight.dtype,
                    weight=emb.weight.detach()
                ))
        return cls(embedding_specs, lr, weight_decay)
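
# Minimal end-to-end sketch (illustrative, not part of the original source).
# It assumes intel_extension_for_pytorch is installed so that the
# torch.ops.torch_ipex kernels referenced above are registered.
if __name__ == "__main__":
    emb1 = nn.EmbeddingBag(200, 128, mode="sum")
    emb2 = nn.EmbeddingBag(100, 128, mode="sum")
    merged_emb = MergedEmbeddingBagWithSGD.from_embeddingbag_list([emb1, emb2], lr=0.01)

    # Two bags per table (batch_size = 2).
    indices = [torch.tensor([10, 50, 3]), torch.tensor([50, 99])]
    offsets = [torch.tensor([0, 2]), torch.tensor([0, 1])]
    include_last_offsets = [False, False]

    # Linearization can also be done offline as a data-preprocessing step.
    merged_input = merged_emb.linearize_indices_and_offsets(indices, offsets, include_last_offsets)
    outputs = merged_emb(merged_input, need_linearize_indices_and_offsets=torch.BoolTensor([False]))
    print([out.shape for out in outputs])  # two tensors of shape (2, 128)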