Refactor config structure #60

Closed · wants to merge 15 commits
15 changes: 9 additions & 6 deletions distributed_shampoo/README.md
@@ -64,7 +64,7 @@ A few notes on hyperparameters:

- We allow for decoupled and coupled weight decay. Setting `use_decoupled_weight_decay=True` enables AdamW-style (decoupled) weight decay, while `use_decoupled_weight_decay=False` corresponds to standard L2-regularization-style weight decay; a short sketch contrasting the two follows these notes.

- When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.
- When setting `preconditioner_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig` (see Example 5), there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.

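Purely as illustration (not part of this PR's diff), a minimal sketch contrasting the two weight-decay styles; the hyperparameter values are placeholders and `instantiate_model` is the same placeholder helper used in the examples below:

```python
from distributed_shampoo import DistributedShampoo

model = instantiate_model()  # placeholder helper, as in the examples below

# AdamW-style (decoupled) weight decay: the decay is applied directly to the
# parameters and does not pass through the preconditioner.
optimizer_decoupled = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    use_decoupled_weight_decay=True,
)

# L2-regularization-style (coupled) weight decay: the decay term is added to the
# gradient before preconditioning.
optimizer_coupled = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    use_decoupled_weight_decay=False,
)
```
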
### Example 1: [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) with Momentum

@@ -221,7 +221,7 @@ optimizer = DistributedShampoo(
)
```

### Example 5: eigenvalue-corrected Shampoo (SOAP)
### Example 5: eigenvalue-corrected Shampoo/SOAP

If we previously used the optimizer:
```python
@@ -241,7 +241,10 @@ optimizer = AdamW(
we would instead use:
```python
import torch
from distributed_shampoo import DistributedShampoo, EighEigenvalueCorrectionConfig
from distributed_shampoo import (
DistributedShampoo,
DefaultEigenvalueCorrectedShampooConfig,
)

model = instantiate_model()

@@ -254,9 +257,9 @@ optimizer = DistributedShampoo(
max_preconditioner_dim=8192,
precondition_frequency=100,
use_decoupled_weight_decay=True,
# This can also be set to `QREigenvalueCorrectionConfig` which is less expensive
# and might therefore allow for a smaller `precondition_frequency`.
preconditioner_computation_config=EighEigenvalueCorrectionConfig(),
# This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is
# less expensive and might thereby allow for a smaller `precondition_frequency`.
preconditioner_config=DefaultEigenvalueCorrectedShampooConfig,
)
```
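
As a sketch (not part of this diff), the QR-based default mentioned in the comment would simply be swapped in; all other hyperparameter values here are placeholders:

```python
from distributed_shampoo import DefaultSOAPConfig, DistributedShampoo

model = instantiate_model()  # placeholder helper, as above

optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-05,
    max_preconditioner_dim=8192,
    precondition_frequency=50,  # possibly smaller, since the QR update is cheaper
    use_decoupled_weight_decay=True,
    preconditioner_config=DefaultSOAPConfig,
)
```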

42 changes: 26 additions & 16 deletions distributed_shampoo/__init__.py
@@ -13,27 +13,30 @@
AdamGraftingConfig,
CommunicationDType,
DDPShampooConfig,
DefaultEigenvalueCorrectedShampooConfig,
DefaultShampooConfig,
DefaultSOAPConfig,
DistributedConfig,
EigenvalueCorrectedShampooPreconditionerConfig,
FSDPShampooConfig,
FullyShardShampooConfig,
GraftingConfig,
HSDPShampooConfig,
PrecisionConfig,
PreconditionerConfig,
RMSpropGraftingConfig,
SGDGraftingConfig,
ShampooPreconditionerConfig,
ShampooPT2CompileConfig,
)
from distributed_shampoo.utils.shampoo_fsdp_utils import compile_fsdp_parameter_metadata
from matrix_functions_types import (
CoupledHigherOrderConfig,
CoupledNewtonConfig,
DefaultEigenConfig,
DefaultEighEigenvalueCorrectionConfig,
EigenConfig,
EigenvalueCorrectionConfig,
EighEigenvalueCorrectionConfig,
PreconditionerComputationConfig,
QREigenvalueCorrectionConfig,
EigenvectorConfig,
MatrixFunctionConfig,
RootInvConfig,
)

@@ -56,17 +59,24 @@
"HSDPShampooConfig",
# `precision_config`.
"PrecisionConfig",
# `preconditioner_computation_config` options.
"PreconditionerComputationConfig", # Abstract base class.
"RootInvConfig", # Abstract base class (based on `PreconditionerComputationConfig`).
"EigenConfig",
"DefaultEigenConfig", # Default `RootInvConfig`.
"CoupledNewtonConfig",
"CoupledHigherOrderConfig",
"EigenvalueCorrectionConfig", # Abstract base class (based on `PreconditionerComputationConfig`).
"EighEigenvalueCorrectionConfig",
"DefaultEighEigenvalueCorrectionConfig", # Default `EigenvalueCorrectionConfig`.
"QREigenvalueCorrectionConfig",
# `preconditioner_config` options.
"PreconditionerConfig", # Abstract base class.
"ShampooPreconditionerConfig", # Based on `PreconditionerConfig`.
"DefaultShampooConfig", # Default `ShampooPreconditionerConfig` using `EigenConfig`.
"EigenvalueCorrectedShampooPreconditionerConfig", # Based on `PreconditionerConfig`.
"DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigenvectorConfig`.
"DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QRConfig`.
# matrix functions configs.
"MatrixFunctionConfig", # Abstract base class.
"RootInvConfig", # Abstract base class (based on `MatrixFunctionConfig`).
"EigenConfig", # Based on `RootInvConfig`.
"DefaultEigenConfig", # Default `RootInvConfig` using `EigenConfig`.
"CoupledNewtonConfig", # Based on `RootInvConfig`.
"CoupledHigherOrderConfig", # Based on `RootInvConfig`.
"EigenvectorConfig", # Abstract base class (based on `MatrixFunctionConfig`).
"EighEigenvectorConfig", # Based on `EigenvectorConfig`.
"DefaultEighEigenvectorConfig", # Default `EigenvectorConfig` using `EighEigenvectorConfig`.
"QRConfig", # Based on `EigenvectorConfig`.
# Other utilities.
"compile_fsdp_parameter_metadata", # For `FSDPShampooConfig` and `HSDPShampooConfig`.
"CommunicationDType", # For `DDPShampooConfig` and `HSDPShampooConfig`.
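
A short sketch of how the refactored hierarchy in the export list above composes; the default configs are pre-built instances, while a custom amortized computation is supplied explicitly:

```python
from distributed_shampoo import (
    CoupledNewtonConfig,
    DefaultShampooConfig,
    ShampooPreconditionerConfig,
)

# Pre-built default: Shampoo whose root inverse is computed via eigendecomposition (EigenConfig).
default_config = DefaultShampooConfig

# Custom: Shampoo whose amortized root-inverse computation uses the coupled Newton iteration.
newton_config = ShampooPreconditionerConfig(
    amortized_computation_config=CoupledNewtonConfig(),
)
```
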
75 changes: 39 additions & 36 deletions distributed_shampoo/distributed_shampoo.py
@@ -24,8 +24,10 @@
BETAS,
DAMPENING,
DDPShampooConfig,
DefaultShampooConfig,
DistributedConfig,
DISTRIBUTOR,
EigenvalueCorrectedShampooPreconditionerConfig,
EPSILON,
FILTERED_GRAD,
FILTERED_GRAD_LIST,
@@ -48,11 +50,13 @@
PRECISION_CONFIG,
PrecisionConfig,
PRECONDITION_FREQUENCY,
PRECONDITIONER_COMPUTATION_CONFIG,
PRECONDITIONER_CONFIG,
PreconditionerConfig,
PREVIOUS_GRAD_SELECTOR,
RMSpropGraftingConfig,
SGDGraftingConfig,
SHAMPOO_PRECONDITIONER_LIST,
ShampooPreconditionerConfig,
ShampooPT2CompileConfig,
START_PRECONDITIONING_STEP,
STEP,
@@ -91,13 +95,7 @@
)
from distributed_shampoo.utils.shampoo_utils import compress_list

from matrix_functions_types import (
DefaultEigenConfig,
EigenConfig,
EigenvalueCorrectionConfig,
PreconditionerComputationConfig,
RootInvConfig,
)
from matrix_functions_types import EigenConfig, RootInvConfig
from torch.optim.optimizer import ParamsT, StateDict

logger: logging.Logger = logging.getLogger(__name__)
@@ -216,7 +214,7 @@ class DistributedShampoo(torch.optim.Optimizer):
updated every iteration while the eigenbasis of Shampoo's preconditioner is only computed every `precondition_frequency` steps.
Alternatively, this can be seen as running Adam in the eigenbasis of Shampoo's preconditioner, also known as SOAP.

When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning
When setting `preconditioner_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig`, there is typically no need to use learning
rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be
a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet.
Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.
@@ -236,16 +234,16 @@
weight_decay (float): Weight decay (L2 penalty). (Default: 0.)
max_preconditioner_dim (int): Maximum preconditioner dimension. (Default: 1024)
precondition_frequency (int): Frequency of updating all components of the preconditioner.
If this field is an instance RootInvConfig, this is the update frequency of the root inverse of the preconditioner.
If this field is an instance EigenvalueCorrectionConfig, this is the update frequency of the eigenbasis of preconditioner.
If preconditioner_config is an instance of ShampooPreconditionerConfig, this is the update frequency of the root inverse of the preconditioner.
If preconditioner_config is an instance of EigenvalueCorrectedShampooPreconditionerConfig, this is the update frequency of the eigenbasis of the preconditioner.
(Default: 1)
start_preconditioning_step (int): Iteration to start computing inverse preconditioner. If -1, uses
the same value as precondition_frequency. (Default: -1)
inv_root_override (int, Sequence[int]): Inverse root to use in Shampoo. If a list [l1, l2, ..., lp], then we will
use -1 / l1 for 1-D tensors (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the
tensor exceeds the length of the list, reverts to the default value. If 0 is used, uses the default inverse
root -1 / (2 * o), where o is the order of the tensor. If preconditioner_computation_config is an instance of
EigenvalueCorrectionConfig, the default is -1 / 2.
root -1 / (2 * o), where o is the order of the tensor. If preconditioner_config is an instance of
EigenvalueCorrectedShampooPreconditionerConfig, the default is -1 / 2.
(Default: 0)
exponent_multiplier (float | None): **DEPRECATING** Number to be multiplied to the numerator of the inverse root, i.e., eta where the
exponent is -eta / (2 * p). (Default: None)
@@ -271,10 +269,10 @@
3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail.
track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes.
(Default: False)
preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation.
If this field is an instance RootInvConfig, Shampoo uses the root inverse of the preconditioner.
If this field is an instance EigenvalueCorrectionConfig Shampoo uses corrected the eigenvalues/running Adam in the eigenbasis of preconditioner.
(Default: DefaultEigenConfig)
preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation.
If this field is an instance of ShampooPreconditionerConfig, Shampoo uses the root inverse of the preconditioner.
If this field is an instance of EigenvalueCorrectedShampooPreconditionerConfig, Shampoo corrects the eigenvalues, i.e., runs Adam in the eigenbasis of the preconditioner.
(Default: DefaultShampooConfig)

"""

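A small sketch (not part of this diff) of the list form of inv_root_override documented above; the values and the model helper are placeholders:

```python
from distributed_shampoo import DistributedShampoo

model = instantiate_model()  # placeholder helper

# With inv_root_override=[2, 2, 4]:
#   1-D tensors (vectors)  -> exponent -1 / 2
#   2-D tensors (matrices) -> exponent -1 / 2
#   3-D tensors            -> exponent -1 / 4
# Higher-order tensors fall back to the default exponent -1 / (2 * o).
optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    inv_root_override=[2, 2, 4],
)
```
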
@@ -305,7 +303,7 @@ def __init__(
precision_config: PrecisionConfig | None = None,
use_protected_eigh: bool = True,
track_root_inv_residuals: bool = False,
preconditioner_computation_config: PreconditionerComputationConfig = DefaultEigenConfig,
preconditioner_config: PreconditionerConfig = DefaultShampooConfig,
) -> None:
# Hyperparameter checks.
if not lr >= 0.0:
@@ -419,11 +417,14 @@ def __init__(
"Both preconditioner_dtype and precision_config are provided. Please use only precision_config as preconditioner_dtype is deprecated."
)

amortized_computation_config = (
preconditioner_config.amortized_computation_config
)
if (
not isinstance(preconditioner_computation_config, RootInvConfig)
not isinstance(amortized_computation_config, RootInvConfig)
) and track_root_inv_residuals:
raise ValueError(
f"{track_root_inv_residuals=} has to be set to False when {preconditioner_computation_config=} is not an instance of RootInvConfig."
f"{track_root_inv_residuals=} has to be set to False when {amortized_computation_config=} is not an instance of RootInvConfig."
)

# Create default precision config if it is not provided.
@@ -432,14 +433,14 @@

# Set exponent multiplier if this is not provided.
if (
isinstance(preconditioner_computation_config, EigenConfig)
isinstance(amortized_computation_config, EigenConfig)
and exponent_multiplier is not None
):
logger.warning(
f"{exponent_multiplier=} is being deprecated. Please consider using EigenConfig.exponent_multiplier directly and setting exponent_multiplier=None instead in the future."
)
preconditioner_computation_config = dataclasses.replace(
preconditioner_computation_config,
amortized_computation_config = dataclasses.replace(
amortized_computation_config,
exponent_multiplier=exponent_multiplier,
)

@@ -463,7 +464,7 @@ def __init__(
GRAFTING_CONFIG: grafting_config,
USE_MERGE_DIMS: use_merge_dims,
PRECISION_CONFIG: precision_config,
PRECONDITIONER_COMPUTATION_CONFIG: preconditioner_computation_config,
PRECONDITIONER_CONFIG: preconditioner_config,
},
)

@@ -534,20 +535,24 @@ def _instantiate_shampoo_preconditioner_list(
for state_lists, group in zip(
self._per_group_state_lists, self.param_groups, strict=True
):
state_lists[SHAMPOO_PRECONDITIONER_LIST] = (
EigenvalueCorrectedShampooPreconditionerList
if isinstance(
group[PRECONDITIONER_COMPUTATION_CONFIG], EigenvalueCorrectionConfig
if type(group[PRECONDITIONER_CONFIG]) is ShampooPreconditionerConfig:
preconditioner_list_cls = ShampooPreconditionerList
elif (
type(group[PRECONDITIONER_CONFIG])
is EigenvalueCorrectedShampooPreconditionerConfig
):
preconditioner_list_cls = EigenvalueCorrectedShampooPreconditionerList # type: ignore[assignment]
else:
raise NotImplementedError(
f"{group[PRECONDITIONER_CONFIG]=} not supported!"
)
else ShampooPreconditionerList
)(

state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls(
block_list=state_lists[DISTRIBUTOR].global_blocked_params,
state=self.state,
block_info_list=state_lists[DISTRIBUTOR].global_block_info_list,
distributor_selector=state_lists[DISTRIBUTOR].distributor_selector,
preconditioner_computation_config=group[
PRECONDITIONER_COMPUTATION_CONFIG
],
preconditioner_config=group[PRECONDITIONER_CONFIG],
precision_config=group[PRECISION_CONFIG],
beta2=group[BETAS][1],
epsilon=group[EPSILON],
@@ -588,9 +593,7 @@ def _instantiate_grafting(self) -> None:
is AdamGraftingConfig,
)
else:
raise NotImplementedError(
f"Unsupported grafting config: {group[GRAFTING_CONFIG]=}."
)
raise NotImplementedError(f"{group[GRAFTING_CONFIG]=} not supported!")

@torch.no_grad()
def _instantiate_steps(self) -> None:
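
A minimal sketch (not part of this diff) of the constraint enforced by the new track_root_inv_residuals check: residual tracking only applies when the amortized computation is a RootInvConfig. The model helper is a placeholder:

```python
from distributed_shampoo import (
    DefaultEigenvalueCorrectedShampooConfig,
    DefaultShampooConfig,
    DistributedShampoo,
)

model = instantiate_model()  # placeholder helper

# OK: DefaultShampooConfig's amortized computation is an EigenConfig, i.e. a RootInvConfig,
# so root-inverse residuals can be tracked.
optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    preconditioner_config=DefaultShampooConfig,
    track_root_inv_residuals=True,
)

# Raises ValueError: eigenvalue-corrected Shampoo computes an eigenbasis instead of a
# root inverse, so there are no root-inverse residuals to track.
optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    preconditioner_config=DefaultEigenvalueCorrectedShampooConfig,
    track_root_inv_residuals=True,
)
```
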
29 changes: 17 additions & 12 deletions distributed_shampoo/examples/trainer_utils.py
@@ -24,16 +24,17 @@
CommunicationDType,
CoupledHigherOrderConfig,
CoupledNewtonConfig,
DefaultEigenvalueCorrectedShampooConfig,
DefaultShampooConfig,
DefaultSOAPConfig,
DistributedConfig,
DistributedShampoo,
EigenConfig,
EighEigenvalueCorrectionConfig,
GraftingConfig,
PrecisionConfig,
PreconditionerComputationConfig,
QREigenvalueCorrectionConfig,
PreconditionerConfig,
RMSpropGraftingConfig,
SGDGraftingConfig,
ShampooPreconditionerConfig,
)
from distributed_shampoo.examples.convnet import ConvNet

@@ -497,7 +498,7 @@ def instantiate_optimizer(
precision_config=precision_config,
use_protected_eigh=use_protected_eigh,
track_root_inv_residuals=track_root_inv_residuals,
preconditioner_computation_config=instantiate_preconditioner_computation_config(
preconditioner_config=instantiate_preconditioner_config(
preconditioner_computation_type
),
) # type: ignore[assignment]
@@ -537,31 +538,35 @@ def instantiate_grafting_config(
raise ValueError(f"Invalid GraftingType {grafting_type}!")


def instantiate_preconditioner_computation_config(
def instantiate_preconditioner_config(
preconditioner_computation_type: PreconditionerComputationType,
) -> PreconditionerComputationConfig:
) -> PreconditionerConfig:
if preconditioner_computation_type == PreconditionerComputationType.EIGEN_ROOT_INV:
return EigenConfig()
return DefaultShampooConfig
elif (
preconditioner_computation_type
== PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV
):
return CoupledNewtonConfig()
return ShampooPreconditionerConfig(
amortized_computation_config=CoupledNewtonConfig(),
)
elif (
preconditioner_computation_type
== PreconditionerComputationType.COUPLED_HIGHER_ORDER_ROOT_INV
):
return CoupledHigherOrderConfig()
return ShampooPreconditionerConfig(
amortized_computation_config=CoupledHigherOrderConfig(),
)
elif (
preconditioner_computation_type
== PreconditionerComputationType.EIGH_EIGENVALUE_CORRECTION
):
return EighEigenvalueCorrectionConfig()
return DefaultEigenvalueCorrectedShampooConfig
elif (
preconditioner_computation_type
== PreconditionerComputationType.QR_EIGENVALUE_CORRECTION
):
return QREigenvalueCorrectionConfig()
return DefaultSOAPConfig
else:
raise ValueError(
f"Invalid PreconditionerComputationType {preconditioner_computation_type}!"
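
For illustration, a sketch of how the mapping above is used, assuming PreconditionerComputationType and this helper resolve from this module:

```python
from distributed_shampoo.examples.trainer_utils import (
    PreconditionerComputationType,
    instantiate_preconditioner_config,
)

# Root-inverse variants map to Shampoo preconditioner configs ...
shampoo_config = instantiate_preconditioner_config(
    PreconditionerComputationType.EIGEN_ROOT_INV
)  # -> DefaultShampooConfig

# ... while eigenvalue-corrected variants map to the eigenvalue-corrected defaults.
soap_config = instantiate_preconditioner_config(
    PreconditionerComputationType.QR_EIGENVALUE_CORRECTION
)  # -> DefaultSOAPConfig
```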