add util function to compute rowwise adagrad updates
Summary:
Added `compute_rowwise_adagrad_updates`, a util function that computes the rowwise adagrad update when we only want to pass in the optimizer state and the gradient, without the parameter.

It also handles the case where the gradient is sparse.

Differential Revision: D58270549
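
For orientation, a minimal sketch of what such a util can look like. The function name comes from this commit, but the signature, argument names, default hyperparameters, and the assumed sparse layout (sparse rows, dense embedding dim, as produced by `EmbeddingBag(..., sparse=True)`) are illustrative guesses, not the actual TorchRec API:

```python
import torch


def compute_rowwise_adagrad_updates_sketch(
    optim_state: torch.Tensor,  # per-row accumulator, shape (num_rows,)
    grad: torch.Tensor,         # dense or sparse gradient, shape (num_rows, dim)
    lr: float = 1e-2,
    eps: float = 1e-10,
) -> torch.Tensor:
    """Hypothetical sketch (not the TorchRec implementation): rowwise adagrad
    accumulates the mean squared gradient per row and scales each row's update
    by a single per-row learning rate."""
    if grad.is_sparse:
        grad = grad.coalesce()
        rows = grad.indices()[0]  # assumes sparse rows, dense columns
        values = grad.values()    # shape (nnz, dim)
        # Only rows that actually received gradients touch the accumulator.
        optim_state.index_add_(0, rows, values.pow(2).mean(dim=1))
        std = optim_state[rows].sqrt().add_(eps)
        return torch.sparse_coo_tensor(
            grad.indices(), -lr * values / std.unsqueeze(1), grad.shape
        )
    # Dense path: every row's accumulator is updated.
    optim_state += grad.pow(2).mean(dim=1)
    std = optim_state.sqrt().add_(eps)
    return -lr * grad / std.unsqueeze(1)
```

Returning the (possibly sparse) delta rather than applying it in place is just one possible convention; the real util may instead mutate the state and return the new values.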
henrylhtsang authored and facebook-github-bot committed Jun 20, 2024
1 parent ac33f23 commit cbb098f
Showing 2 changed files with 17 additions and 11 deletions.
15 changes: 7 additions & 8 deletions torchrec/optim/rowwise_adagrad.py
```diff
@@ -69,14 +69,6 @@ def __init__(
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
 
-        if weight_decay > 0:
-            logger.warning(
-                "Note that the weight decay mode of this optimizer may produce "
-                "different results compared to the one by FBGEMM TBE. This is "
-                "due to FBGEMM TBE rowwise adagrad is sparse, and will only "
-                "update the optimizer states if that row has nonzero gradients."
-            )
-
         defaults = dict(
             lr=lr,
             lr_decay=lr_decay,
@@ -213,6 +205,13 @@ def _single_tensor_adagrad(
     eps: float,
     maximize: bool,
 ) -> None:
+    if weight_decay != 0 and len(state_steps) > 0 and state_steps[0].item() < 1.0:
+        logger.warning(
+            "Note that the weight decay mode of this optimizer may produce "
+            "different results compared to the one by FBGEMM TBE. This is "
+            "due to FBGEMM TBE rowwise adagrad is sparse, and will only "
+            "update the optimizer states if that row has nonzero gradients."
+        )
+
     for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps):
         if grad.is_sparse:
```
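
The warning moved above describes a real behavioral gap; a tiny, purely illustrative sketch (hypothetical tensors, not TorchRec or FBGEMM code) of why dense weight decay and a sparse rowwise update diverge for rows that receive no gradient:

```python
import torch

# Hypothetical 3-row table; only row 0 receives a gradient in this batch.
state = torch.ones(3)       # rowwise accumulator
weight = torch.ones(3, 4)
weight_decay = 0.1

grad = torch.zeros(3, 4)
grad[0] = 1.0

# Dense-style weight decay: decay is folded into the gradient of every row,
# so even untouched rows 1 and 2 accumulate optimizer state.
dense_grad = grad + weight_decay * weight
state_dense = state + dense_grad.pow(2).mean(dim=1)

# FBGEMM-TBE-style sparse update: only rows that appear in the batch are
# visited, so rows 1 and 2 keep their old accumulator (and see no decay).
state_sparse = state.clone()
state_sparse[0] += dense_grad[0].pow(2).mean()

print(state_dense)   # tensor([2.2100, 1.0100, 1.0100])
print(state_sparse)  # tensor([2.2100, 1.0000, 1.0000])
```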
13 changes: 10 additions & 3 deletions torchrec/optim/tests/test_rowwise_adagrad.py
```diff
@@ -16,17 +16,24 @@
 
 class RowWiseAdagradTest(unittest.TestCase):
     def test_optim(self) -> None:
-        embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4)
+        embedding_bag = torch.nn.EmbeddingBag(
+            num_embeddings=4, embedding_dim=4, mode="sum"
+        )
         opt = torchrec.optim.RowWiseAdagrad(embedding_bag.parameters())
         index, offsets = torch.tensor([0, 3]), torch.tensor([0, 1])
         embedding_bag_out = embedding_bag(index, offsets)
         opt.zero_grad()
         embedding_bag_out.sum().backward()
         opt.step()
 
     def test_optim_equivalence(self) -> None:
+        # If rows are initialized to be the same and uniform, then RowWiseAdagrad and canonical Adagrad are identical
-        rowwise_embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4)
-        embedding_bag = torch.nn.EmbeddingBag(num_embeddings=4, embedding_dim=4)
+        rowwise_embedding_bag = torch.nn.EmbeddingBag(
+            num_embeddings=4, embedding_dim=4, mode="sum"
+        )
+        embedding_bag = torch.nn.EmbeddingBag(
+            num_embeddings=4, embedding_dim=4, mode="sum"
+        )
         state_dict = {
             "weight": torch.Tensor(
                 [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]
```
