Skip to content

Commit

Permalink
Open-sourced update on 12/05/2024 (#58)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #58

Bug fix on `MASKED_FILTERED_GRAD_LIST` and `MASKED_MOMENTUM_LIST` updates.

Reviewed By: chuanhaozhuge

Differential Revision: D66826053

fbshipit-source-id: 815410c56835183955768ff2b4bbbeb98b4865b7
  • Loading branch information
tsunghsienlee authored and facebook-github-bot committed Dec 9, 2024
1 parent 46fdd40 commit f9d2a8c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
6 changes: 2 additions & 4 deletions distributed_shampoo/distributed_shampoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,14 +777,12 @@ def _mask_state_lists(state_lists: dict[str, Any], group: dict[str, Any]) -> Non
)
if group[BETAS][0] != 0.0:
state_lists[MASKED_FILTERED_GRAD_LIST] = state_lists[
MASKED_FILTERED_GRAD_LIST
FILTERED_GRAD_LIST
].compress(
state_lists[DISTRIBUTOR].local_grad_selector,
)
if group[MOMENTUM] != 0.0:
state_lists[MASKED_MOMENTUM_LIST] = state_lists[
MASKED_MOMENTUM_LIST
].compress(
state_lists[MASKED_MOMENTUM_LIST] = state_lists[MOMENTUM_LIST].compress(
state_lists[DISTRIBUTOR].local_grad_selector,
)

Expand Down
1 change: 1 addition & 0 deletions distributed_shampoo/gpu_tests/shampoo_grafting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def test_rmsprop_grafting_on_quadratic(self) -> None:
max_preconditioner_dim=10,
precondition_frequency=1,
start_preconditioning_step=math.inf,
use_bias_correction=False,
use_decoupled_weight_decay=False,
grafting_config=RMSpropGraftingConfig(
beta2=0.99,
Expand Down

0 comments on commit f9d2a8c

Please sign in to comment.