Fix an issue applying momentum to narrow embedding layers in SGD/SGDW
With embedding layers where one of the dimensions has size 1 (e.g. a single embedding dimension used to represent biases), squeezing the momentum values removed the size-1 dimension entirely. This adds a `reshape` so that the indices and values of the constructed sparse momentum tensor match each other.
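
The failure mode can be reproduced with a toy example. Below is a minimal sketch (illustrative only, not code from this repository; the shapes and names are assumptions) of how `squeeze()` collapses the width-1 dimension of a narrow embedding's sparse gradient values, and how `reshape` realigns them with the indices:

import torch

# A "narrow" embedding: 10 rows, embedding_dim == 1 (e.g. per-item bias terms).
emb = torch.nn.Embedding(10, 1, sparse=True)
emb(torch.tensor([2, 5, 7])).sum().backward()

grad = emb.weight.grad.coalesce()
grad_inds = grad.indices()     # shape (1, 3)
grad_values = grad.values()    # shape (3, 1)

# squeeze() drops *every* size-1 dimension, so the (3, 1) values collapse to
# (3,) and no longer line up with the (1, 3) indices when rebuilding a sparse
# tensor over the (10, 1) weight.
buf = grad_values.squeeze() * 0.9              # momentum-style buffer, shape (3,)

# The fix: reshape the values back to grad_values' shape before constructing
# the sparse momentum tensor (the repo builds it via grad.new(...)).
momentum = torch.sparse_coo_tensor(
    grad_inds, buf.reshape(grad_values.shape), emb.weight.shape
)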
karlhigley committed Oct 3, 2020
1 parent 61ca507 commit b6dd32f
Showing 4 changed files with 5 additions and 5 deletions.
pyproject.toml: 4 changes (2 additions & 2 deletions)
@@ -1,7 +1,7 @@
 [tool.poetry]
 name = "torch-optim-sparse"
-version = "0.1.2"
-description = "Truly sparse optimizers for PyTorch"
+version = "0.1.3"
+description = "PyTorch optimizers with sparse momentum and weight decay"
 authors = ["Karl Higley <[email protected]>"]

 [tool.poetry.dependencies]
torch_optim_sparse/__init__.py: 2 changes (1 addition & 1 deletion)
@@ -1,4 +1,4 @@
-__version__ = '0.1.2'
+__version__ = '0.1.3'

 from .sparser_adam import SparserAdam
 from .sparser_adamw import SparserAdamW
torch_optim_sparse/sparser_sgd.py: 2 changes (1 addition & 1 deletion)
@@ -39,7 +39,7 @@ def make_sparse(values):
     constructor = grad.new
     if grad_inds.dim() == 0 or values.dim() == 0:
         return constructor().resize_as_(grad)
-    return constructor(grad_inds, values, size)
+    return constructor(grad_inds, values.reshape(grad_values.shape), size)

 if weight_decay != 0:
     param_values = p.data[grad_inds].squeeze()
torch_optim_sparse/sparser_sgdw.py: 2 changes (1 addition & 1 deletion)
@@ -39,7 +39,7 @@ def make_sparse(values):
     constructor = grad.new
     if grad_inds.dim() == 0 or values.dim() == 0:
         return constructor().resize_as_(grad)
-    return constructor(grad_inds, values, size)
+    return constructor(grad_inds, values.reshape(grad_values.shape), size)

 if momentum != 0:
     param_state = self.state[p]
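
For completeness, a hypothetical usage sketch of the case this fix targets. It assumes the package exposes `SparserSGD` with the usual `params`/`lr`/`momentum` constructor arguments, which is not shown in this diff:

import torch
from torch_optim_sparse import SparserSGD  # assumed export, matching sparser_sgd.py

# Width-1 embedding, e.g. one bias value per item.
bias_emb = torch.nn.Embedding(100, 1, sparse=True)
opt = SparserSGD(bias_emb.parameters(), lr=0.01, momentum=0.9)

for _ in range(2):  # the second step applies the momentum buffer built on the first
    opt.zero_grad()
    loss = bias_emb(torch.randint(0, 100, (32,))).pow(2).sum()
    loss.backward()
    opt.step()  # before this fix, the momentum values could lose their size-1 dimension here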
