diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index ef8e2ce..108ab3e 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -777,14 +777,12 @@ def _mask_state_lists(state_lists: dict[str, Any], group: dict[str, Any]) -> Non ) if group[BETAS][0] != 0.0: state_lists[MASKED_FILTERED_GRAD_LIST] = state_lists[ - MASKED_FILTERED_GRAD_LIST + FILTERED_GRAD_LIST ].compress( state_lists[DISTRIBUTOR].local_grad_selector, ) if group[MOMENTUM] != 0.0: - state_lists[MASKED_MOMENTUM_LIST] = state_lists[ - MASKED_MOMENTUM_LIST - ].compress( + state_lists[MASKED_MOMENTUM_LIST] = state_lists[MOMENTUM_LIST].compress( state_lists[DISTRIBUTOR].local_grad_selector, ) diff --git a/distributed_shampoo/gpu_tests/shampoo_grafting_test.py b/distributed_shampoo/gpu_tests/shampoo_grafting_test.py index 1c5907a..1f3bc05 100644 --- a/distributed_shampoo/gpu_tests/shampoo_grafting_test.py +++ b/distributed_shampoo/gpu_tests/shampoo_grafting_test.py @@ -180,6 +180,7 @@ def test_rmsprop_grafting_on_quadratic(self) -> None: max_preconditioner_dim=10, precondition_frequency=1, start_preconditioning_step=math.inf, + use_bias_correction=False, use_decoupled_weight_decay=False, grafting_config=RMSpropGraftingConfig( beta2=0.99,