From cbb098f0db0da1f9f33a8bdd33e72b6da571f388 Mon Sep 17 00:00:00 2001
From: Henry Tsang
Date: Thu, 20 Jun 2024 13:11:18 -0700
Subject: [PATCH] add util function to compute rowwise adagrad updates

Summary:
Added `compute_rowwise_adagrad_updates`, a util function that computes the
rowwise adagrad update when we only want to pass in optim_state and grad,
without the parameter. It can handle the case when grad is sparse.

Differential Revision: D58270549
---
 torchrec/optim/rowwise_adagrad.py            | 15 +++++++--------
 torchrec/optim/tests/test_rowwise_adagrad.py | 13 ++++++++++---
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/torchrec/optim/rowwise_adagrad.py b/torchrec/optim/rowwise_adagrad.py
index a037d0404..6712b115e 100644
--- a/torchrec/optim/rowwise_adagrad.py
+++ b/torchrec/optim/rowwise_adagrad.py
@@ -69,14 +69,6 @@ def __init__(
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
 
-        if weight_decay > 0:
-            logger.warning(
-                "Note that the weight decay mode of this optimizer may produce "
-                "different results compared to the one by FBGEMM TBE. This is "
-                "due to FBGEMM TBE rowwise adagrad is sparse, and will only "
-                "update the optimizer states if that row has nonzero gradients."
-            )
-
         defaults = dict(
             lr=lr,
             lr_decay=lr_decay,
@@ -213,6 +205,13 @@ def _single_tensor_adagrad(
     eps: float,
     maximize: bool,
 ) -> None:
+    if weight_decay != 0 and len(state_steps) > 0 and state_steps[0].item() < 1.0:
+        logger.warning(
+            "Note that the weight decay mode of this optimizer may produce "
+            "different results compared to the one by FBGEMM TBE. This is "
+            "due to FBGEMM TBE rowwise adagrad is sparse, and will only "
+            "update the optimizer states if that row has nonzero gradients."
+        )
     for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps):
 
         if grad.is_sparse:
diff --git a/torchrec/optim/tests/test_rowwise_adagrad.py b/torchrec/optim/tests/test_rowwise_adagrad.py
index 7166b9c03..3a221b412 100644
--- a/torchrec/optim/tests/test_rowwise_adagrad.py
+++ b/torchrec/optim/tests/test_rowwise_adagrad.py
@@ -16,17 +16,24 @@
 
 class RowWiseAdagradTest(unittest.TestCase):
     def test_optim(self) -> None:
-        embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4)
+        embedding_bag = torch.nn.EmbeddingBag(
+            num_embeddings=4, embedding_dim=4, mode="sum"
+        )
         opt = torchrec.optim.RowWiseAdagrad(embedding_bag.parameters())
         index, offsets = torch.tensor([0, 3]), torch.tensor([0, 1])
         embedding_bag_out = embedding_bag(index, offsets)
         opt.zero_grad()
         embedding_bag_out.sum().backward()
+        opt.step()
 
     def test_optim_equivalence(self) -> None:
         # If rows are initialized to be the same and uniform, then RowWiseAdagrad and canonical Adagrad are identical
-        rowwise_embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4)
-        embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4)
+        rowwise_embedding_bag = torch.nn.EmbeddingBag(
+            num_embeddings=4, embedding_dim=4, mode="sum"
+        )
+        embedding_bag = torch.nn.EmbeddingBag(
+            num_embeddings=4, embedding_dim=4, mode="sum"
+        )
         state_dict = {
             "weight": torch.Tensor(
                 [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]
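
Note: the Summary describes `compute_rowwise_adagrad_updates` as computing the
rowwise adagrad update from just the optimizer state and the gradient (including
the sparse-gradient case), but the helper itself is not part of the hunks shown
above. The sketch below only illustrates the rowwise adagrad math such a helper
would need; the function name, signature, and defaults are assumptions for
illustration, not the actual torchrec API, and weight decay / lr decay are
omitted.

import torch

def rowwise_adagrad_update_sketch(
    state_sum: torch.Tensor,  # per-row accumulator, shape (num_rows,)
    grad: torch.Tensor,       # gradient of the embedding table, (num_rows, dim); may be sparse COO
    lr: float = 0.01,
    eps: float = 1e-10,
) -> torch.Tensor:
    """Update state_sum in place and return the delta to add to the parameter."""
    if grad.is_sparse:
        grad = grad.coalesce()                 # merge duplicate row indices
        rows = grad.indices()[0]               # (nnz,) ids of rows with nonzero gradients
        vals = grad.values()                   # (nnz, dim) per-row gradient values
        # accumulate the rowwise mean of squared gradients, only for touched rows
        state_sum.index_add_(0, rows, vals.pow(2).mean(dim=1))
        std = state_sum[rows].sqrt().add(eps)  # (nnz,) rowwise denominators
        delta_vals = vals.div(std.unsqueeze(1)).mul(-lr)
        return torch.sparse_coo_tensor(grad.indices(), delta_vals, grad.size())
    # dense path: every row gets its accumulator bumped, even zero-gradient rows
    state_sum.add_(grad.pow(2).mean(dim=1))
    std = state_sum.sqrt().add(eps)
    return grad.div(std.unsqueeze(1)).mul(-lr)

A caller that does hold the parameter would then apply the returned delta, e.g.
param.add_(delta) (or param.add_(delta.to_dense()) in the sparse case). The
dense/sparse asymmetry above is also the reason for the relocated warning in the
diff: the sparse path only touches rows with nonzero gradients, so weight decay
results can drift from FBGEMM TBE's rowwise adagrad.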