Skip to content

Commit 5872b20

Browse files
authored
Implement torch.ops.aten.embedding_renorm_ (#8091)
1 parent ecc0f5a commit 5872b20

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

experimental/torch_xla2/test/test_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@
9696
"nn.functional.dropout3d",
9797
"nn.functional.dropout",
9898
"nn.functional.embedding_bag",
99-
"nn.functional.embedding",
10099
"nn.functional.fractional_max_pool2d",
101100
"nn.functional.fractional_max_pool3d",
102101
"nn.functional.group_norm",

experimental/torch_xla2/torch_xla2/ops/jaten.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,9 +367,30 @@ def _aten_bmm(x, y):
367367

368368
@op(torch.ops.aten.embedding)
# embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False)
def _aten_embedding(a, w, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
  """Forward embedding lookup: gather rows of the weight matrix `a` at `w`.

  `padding_idx`, `scale_grad_by_freq`, and `sparse` only influence autograd /
  storage on the torch side; the forward pass is a plain row gather, so they
  are accepted for signature compatibility and otherwise ignored here.
  """
  return jnp.take(a, indices=w, axis=0)
372372

373+
@op(torch.ops.aten.embedding_renorm_)
def _aten_embedding_renorm_(weight, indices, max_norm, norm_type):
  """Renormalize the rows of `weight` selected by `indices` so that each
  row's `norm_type`-norm is at most `max_norm`.

  Adapted from
  https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp
  """
  # Each referenced row is rescaled at most once, even if it appears
  # multiple times in `indices`.
  uniq = jnp.unique(indices)

  row_norms = jnp.linalg.norm(
      _aten_embedding(weight, uniq),
      ord=norm_type,
      axis=1,
  )

  # Positions (within `uniq`) of rows whose norm exceeds the cap.
  over = jnp.where(row_norms > max_norm)
  # Epsilon matches the PyTorch reference implementation and guards the divide.
  factors = max_norm / (row_norms[over] + 1e-7)
  rows = uniq[over]

  return weight.at[rows].set(weight[rows] * factors[:, None])
373394

374395
#- func: _embedding_bag_forward_only(
375396
# Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False,

0 commit comments

Comments
 (0)