@@ -367,9 +367,34 @@ def _aten_bmm(x, y):
 @op(torch.ops.aten.embedding)
 # embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False)
-def _aten_embedding(a, w, padding_idx=-1):
+def _aten_embedding(a, w, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
   return jnp.take(a, w, axis=0)
 
+@op(torch.ops.aten.embedding_renorm_)
+def _aten_embedding_renorm_(weight, indices, max_norm, norm_type):
+  # Adapted from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp
+  unique_indices = jnp.unique(indices)
+
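+  # Compute the norm_type-norm of each referenced (unique) embedding row.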
+  norm = jnp.linalg.norm(
+      _aten_embedding(weight, unique_indices),
+      ord=norm_type,
+      axis=1,
+  )
+
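+  # Positions (within unique_indices) of rows whose norm exceeds max_norm.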
+  indice_idx = jnp.where(norm > max_norm)
+
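+  # Rescaling factor per offending row; the 1e-7 epsilon avoids division
+  # by zero and mirrors the PyTorch reference implementation cited above.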
+  scale = max_norm / (norm[indice_idx] + 1e-7)
+
+  indices_to_update = unique_indices[indice_idx]
+
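+  # JAX arrays are immutable, so .at[...].set(...) returns an updated copy
+  # of weight rather than mutating it in place.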
+  weight = weight.at[indices_to_update].set(
+      weight[indices_to_update] * scale[:, None]
+  )
+  return weight
 
 #- func: _embedding_bag_forward_only(
 # Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False,
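
Note: aten.embedding_renorm_ is the op that torch.nn.functional.embedding dispatches to when max_norm is set. As a rough sketch of the behavior the new lowering implements (the values below are illustrative, not part of the change):

    import jax.numpy as jnp

    weight = jnp.array([[3.0, 4.0], [0.3, 0.4]])  # row 0 has L2 norm 5.0
    indices = jnp.array([0, 0, 1])
    out = _aten_embedding_renorm_(weight, indices, max_norm=1.0, norm_type=2.0)
    # row 0 is scaled down to roughly unit norm; row 1 (norm 0.5) is untouched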