diff --git a/distributed_shampoo/README.md b/distributed_shampoo/README.md index 86c8cde..d197cd0 100644 --- a/distributed_shampoo/README.md +++ b/distributed_shampoo/README.md @@ -64,7 +64,7 @@ A few notes on hyperparameters: - We allow for decoupled and coupled weight decay. If one sets `use_decoupled_weight_decay=True`, then you are enabling AdamW-style weight decay, while `use_decoupled_weight_decay=False` corresponds to the normal L2-regularization style weight decay. -- When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. +- When setting `preconditioner_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig` (see Example 5), there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. ### Example 1: [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) with Momentum @@ -221,7 +221,7 @@ optimizer = DistributedShampoo( ) ``` -### Example 5: eigenvalue-corrected Shampoo (SOAP) +### Example 5: eigenvalue-corrected Shampoo/SOAP If we previously used the optimizer: ```python @@ -241,7 +241,10 @@ optimizer = AdamW( we would instead use: ```python import torch -from distributed_shampoo import DistributedShampoo, EighEigenvalueCorrectionConfig +from distributed_shampoo import ( + DistributedShampoo, + DefaultEigenvalueCorrectedShampooConfig, +) model = instantiate_model() @@ -254,9 +257,9 @@ optimizer = DistributedShampoo( max_preconditioner_dim=8192, precondition_frequency=100, use_decoupled_weight_decay=True, - # This can also be set to `QREigenvalueCorrectionConfig` which is less expensive - # and might therefore allow for a smaller `precondition_frequency`. - preconditioner_computation_config=EighEigenvalueCorrectionConfig(), + # This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is + # less expensive and might thereby allow for a smaller `precondition_frequency`. 
+ preconditioner_config=DefaultEigenvalueCorrectedShampooConfig, ) ``` diff --git a/distributed_shampoo/__init__.py b/distributed_shampoo/__init__.py index 2b2c45f..97a898a 100644 --- a/distributed_shampoo/__init__.py +++ b/distributed_shampoo/__init__.py @@ -13,14 +13,20 @@ AdamGraftingConfig, CommunicationDType, DDPShampooConfig, + DefaultEigenvalueCorrectedShampooConfig, + DefaultShampooConfig, + DefaultSOAPConfig, DistributedConfig, + EigenvalueCorrectedShampooPreconditionerConfig, FSDPShampooConfig, FullyShardShampooConfig, GraftingConfig, HSDPShampooConfig, PrecisionConfig, + PreconditionerConfig, RMSpropGraftingConfig, SGDGraftingConfig, + ShampooPreconditionerConfig, ShampooPT2CompileConfig, ) from distributed_shampoo.utils.shampoo_fsdp_utils import compile_fsdp_parameter_metadata @@ -28,12 +34,9 @@ CoupledHigherOrderConfig, CoupledNewtonConfig, DefaultEigenConfig, - DefaultEighEigenvalueCorrectionConfig, EigenConfig, - EigenvalueCorrectionConfig, - EighEigenvalueCorrectionConfig, - PreconditionerComputationConfig, - QREigenvalueCorrectionConfig, + EigenvectorConfig, + MatrixFunctionConfig, RootInvConfig, ) @@ -56,17 +59,24 @@ "HSDPShampooConfig", # `precision_config`. "PrecisionConfig", - # `preconditioner_computation_config` options. - "PreconditionerComputationConfig", # Abstract base class. - "RootInvConfig", # Abstract base class (based on `PreconditionerComputationConfig`). - "EigenConfig", - "DefaultEigenConfig", # Default `RootInvConfig`. - "CoupledNewtonConfig", - "CoupledHigherOrderConfig", - "EigenvalueCorrectionConfig", # Abstract base class (based on `PreconditionerComputationConfig`). - "EighEigenvalueCorrectionConfig", - "DefaultEighEigenvalueCorrectionConfig", # Default `EigenvalueCorrectionConfig`. - "QREigenvalueCorrectionConfig", + # `preconditioner_config` options. + "PreconditionerConfig", # Abstract base class. + "ShampooPreconditionerConfig", # Based on `PreconditionerConfig`. + "DefaultShampooConfig", # Default `ShampooPreconditionerConfig` using `EigenConfig`. + "EigenvalueCorrectedShampooPreconditionerConfig", # Based on `PreconditionerConfig`. + "DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigenvectorConfig`. + "DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QRConfig`. + # matrix functions configs. + "MatrixFunctionConfig", # Abstract base class. + "RootInvConfig", # Abstract base class (based on `MatrixFunctionConfig`). + "EigenConfig", # Based on `RootInvConfig`. + "DefaultEigenConfig", # Default `RootInvConfig` using `EigenConfig`. + "CoupledNewtonConfig", # Based on `RootInvConfig`. + "CoupledHigherOrderConfig", # Based on `RootInvConfig`. + "EigenvectorConfig", # Abstract base class (based on `MatrixFunctionConfig`). + "EighEigenvectorConfig", # Based on `EigenvectorConfig`. + "DefaultEighEigenvectorConfig", # Default `EigenvectorConfig` using `EighEigenvectorConfig`. + "QRConfig", # Based on `EigenvectorConfig`. # Other utilities. "compile_fsdp_parameter_metadata", # For `FSDPShampooConfig` and `HSDPShampooConfig`. "CommunicationDType", # For `DDPShampooConfig` and `HSDPShampooConfig`. 
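A minimal sketch (not part of the patch) of how the renamed `preconditioner_config` argument is used after this change, mirroring README Example 5; the toy two-layer model below is assumed purely for illustration.

```python
# Sketch only: illustrates the renamed `preconditioner_config` argument.
# The toy model is an assumption, not code from the repository.
import torch
from distributed_shampoo import (
    DefaultEigenvalueCorrectedShampooConfig,
    DefaultShampooConfig,
    DistributedShampoo,
)

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.Linear(16, 2))

# Standard Shampoo (root-inverse preconditioner); DefaultShampooConfig is also the default.
shampoo = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    preconditioner_config=DefaultShampooConfig,
)

# Eigenvalue-corrected Shampoo/SOAP, replacing the old
# `preconditioner_computation_config=EighEigenvalueCorrectionConfig()` spelling.
soap_style = DistributedShampoo(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    epsilon=1e-8,
    grafting_config=None,
    preconditioner_config=DefaultEigenvalueCorrectedShampooConfig,
)
```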
diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 108ab3e..39fe222 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -24,8 +24,10 @@ BETAS, DAMPENING, DDPShampooConfig, + DefaultShampooConfig, DistributedConfig, DISTRIBUTOR, + EigenvalueCorrectedShampooPreconditionerConfig, EPSILON, FILTERED_GRAD, FILTERED_GRAD_LIST, @@ -48,11 +50,13 @@ PRECISION_CONFIG, PrecisionConfig, PRECONDITION_FREQUENCY, - PRECONDITIONER_COMPUTATION_CONFIG, + PRECONDITIONER_CONFIG, + PreconditionerConfig, PREVIOUS_GRAD_SELECTOR, RMSpropGraftingConfig, SGDGraftingConfig, SHAMPOO_PRECONDITIONER_LIST, + ShampooPreconditionerConfig, ShampooPT2CompileConfig, START_PRECONDITIONING_STEP, STEP, @@ -91,13 +95,7 @@ ) from distributed_shampoo.utils.shampoo_utils import compress_list -from matrix_functions_types import ( - DefaultEigenConfig, - EigenConfig, - EigenvalueCorrectionConfig, - PreconditionerComputationConfig, - RootInvConfig, -) +from matrix_functions_types import EigenConfig, RootInvConfig from torch.optim.optimizer import ParamsT, StateDict logger: logging.Logger = logging.getLogger(__name__) @@ -216,7 +214,7 @@ class DistributedShampoo(torch.optim.Optimizer): updated every iteration while the eigenbasis of Shampoo's preconditioner is only computed every `precondition_frequency` steps. Alternatively, this can be seen as running Adam in the eigenbasis of Shampoo's preconditioner, also known as SOAP. - When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning + When setting `preconditioner_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig`, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. @@ -236,16 +234,16 @@ class DistributedShampoo(torch.optim.Optimizer): weight_decay (float): Weight decay (L2 penalty). (Default: 0.) max_preconditioner_dim (int): Maximum preconditioner dimension. (Default: 1024) precondition_frequency (int): Frequency of updating all components of the preconditioner. - If this field is an instance RootInvConfig, this is the update frequency of the root inverse of the preconditioner. - If this field is an instance EigenvalueCorrectionConfig, this is the update frequency of the eigenbasis of preconditioner. + If preconditioner_config is an instance of ShampooPreconditionerConfig, this is the update frequency of the root inverse of the preconditioner. + If preconditioner_config is an instance of EigenvalueCorrectedShampooPreconditionerConfig, this is the update frequency of the eigenbasis of the preconditioner. (Default: 1) start_preconditioning_step (int): Iteration to start computing inverse preconditioner. If -1, uses the same value as precondition_frequency. (Default: -1) inv_root_override (int, Sequence[int]): Inverse root to use in Shampoo. If a list [l1, l2, ..., lp], then we will use -1 / l1 for 1-D tensor (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the tensor exceeds the order of the tensor, reverts to the default value.
If 0 is used, uses the default inverse - root -1 / (2 * o), where o is the order of the tensor. If preconditioner_computation_config is an instance of - EigenvalueCorrectionConfig, the default is -1 / 2. + root -1 / (2 * o), where o is the order of the tensor. If preconditioner_config is an instance of + EigenvalueCorrectedShampooPreconditionerConfig, the default is -1 / 2. (Default: 0) exponent_multiplier (float | None): **DEPRECATING** Number to be multiplied to the numerator of the inverse root, i.e., eta where the exponent is -eta / (2 * p). (Default: None) @@ -271,10 +269,10 @@ class DistributedShampoo(torch.optim.Optimizer): 3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail. track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes. (Default: False) - preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation. - If this field is an instance RootInvConfig, Shampoo uses the root inverse of the preconditioner. - If this field is an instance EigenvalueCorrectionConfig Shampoo uses corrected the eigenvalues/running Adam in the eigenbasis of preconditioner. - (Default: DefaultEigenConfig) + preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation. + If this field is an instance of ShampooPreconditionerConfig, Shampoo uses the root inverse of the preconditioner. + If this field is an instance of EigenvalueCorrectedShampooPreconditionerConfig, Shampoo uses the corrected eigenvalues, i.e., runs Adam in the eigenbasis of the preconditioner. + (Default: DefaultShampooConfig) """ @@ -305,7 +303,7 @@ def __init__( precision_config: PrecisionConfig | None = None, use_protected_eigh: bool = True, track_root_inv_residuals: bool = False, - preconditioner_computation_config: PreconditionerComputationConfig = DefaultEigenConfig, + preconditioner_config: PreconditionerConfig = DefaultShampooConfig, ) -> None: # Hyperparameter checks. if not lr >= 0.0: @@ -419,11 +417,14 @@ def __init__( "Both preconditioner_dtype and precision_config are provided. Please use only precision_config as preconditioner_dtype is deprecated." ) + amortized_computation_config = ( + preconditioner_config.amortized_computation_config + ) if ( - not isinstance(preconditioner_computation_config, RootInvConfig) + not isinstance(amortized_computation_config, RootInvConfig) ) and track_root_inv_residuals: raise ValueError( - f"{track_root_inv_residuals=} has to be set to False when {preconditioner_computation_config=} is not an instance of RootInvConfig." + f"{track_root_inv_residuals=} has to be set to False when {amortized_computation_config=} is not an instance of RootInvConfig." ) # Create default precision config if it is not provided. @@ -432,14 +433,14 @@ def __init__( # Set exponent multiplier if this is not provided. if ( - isinstance(preconditioner_computation_config, EigenConfig) + isinstance(amortized_computation_config, EigenConfig) and exponent_multiplier is not None ): logger.warning( f"{exponent_multiplier=} is deprecating. Please consider using EigenConfig.exponent_multiplier directly and setting exponent_multipler=None instead in the future."
) - preconditioner_computation_config = dataclasses.replace( - preconditioner_computation_config, + amortized_computation_config = dataclasses.replace( + amortized_computation_config, exponent_multiplier=exponent_multiplier, ) @@ -463,7 +464,7 @@ def __init__( GRAFTING_CONFIG: grafting_config, USE_MERGE_DIMS: use_merge_dims, PRECISION_CONFIG: precision_config, - PRECONDITIONER_COMPUTATION_CONFIG: preconditioner_computation_config, + PRECONDITIONER_CONFIG: preconditioner_config, }, ) @@ -534,20 +535,24 @@ def _instantiate_shampoo_preconditioner_list( for state_lists, group in zip( self._per_group_state_lists, self.param_groups, strict=True ): - state_lists[SHAMPOO_PRECONDITIONER_LIST] = ( - EigenvalueCorrectedShampooPreconditionerList - if isinstance( - group[PRECONDITIONER_COMPUTATION_CONFIG], EigenvalueCorrectionConfig + if type(group[PRECONDITIONER_CONFIG]) is ShampooPreconditionerConfig: + preconditioner_list_cls = ShampooPreconditionerList + elif ( + type(group[PRECONDITIONER_CONFIG]) + is EigenvalueCorrectedShampooPreconditionerConfig + ): + preconditioner_list_cls = EigenvalueCorrectedShampooPreconditionerList # type: ignore[assignment] + else: + raise NotImplementedError( + f"{group[PRECONDITIONER_CONFIG]=} not supported!" ) - else ShampooPreconditionerList - )( + + state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls( block_list=state_lists[DISTRIBUTOR].global_blocked_params, state=self.state, block_info_list=state_lists[DISTRIBUTOR].global_block_info_list, distributor_selector=state_lists[DISTRIBUTOR].distributor_selector, - preconditioner_computation_config=group[ - PRECONDITIONER_COMPUTATION_CONFIG - ], + preconditioner_config=group[PRECONDITIONER_CONFIG], precision_config=group[PRECISION_CONFIG], beta2=group[BETAS][1], epsilon=group[EPSILON], @@ -588,9 +593,7 @@ def _instantiate_grafting(self) -> None: is AdamGraftingConfig, ) else: - raise NotImplementedError( - f"Unsupported grafting config: {group[GRAFTING_CONFIG]=}." 
- ) + raise NotImplementedError(f"{group[GRAFTING_CONFIG]=} not supported!") @torch.no_grad() def _instantiate_steps(self) -> None: diff --git a/distributed_shampoo/examples/trainer_utils.py b/distributed_shampoo/examples/trainer_utils.py index 635a57b..313d980 100644 --- a/distributed_shampoo/examples/trainer_utils.py +++ b/distributed_shampoo/examples/trainer_utils.py @@ -24,16 +24,17 @@ CommunicationDType, CoupledHigherOrderConfig, CoupledNewtonConfig, + DefaultEigenvalueCorrectedShampooConfig, + DefaultShampooConfig, + DefaultSOAPConfig, DistributedConfig, DistributedShampoo, - EigenConfig, - EighEigenvalueCorrectionConfig, GraftingConfig, PrecisionConfig, - PreconditionerComputationConfig, - QREigenvalueCorrectionConfig, + PreconditionerConfig, RMSpropGraftingConfig, SGDGraftingConfig, + ShampooPreconditionerConfig, ) from distributed_shampoo.examples.convnet import ConvNet @@ -497,7 +498,7 @@ def instantiate_optimizer( precision_config=precision_config, use_protected_eigh=use_protected_eigh, track_root_inv_residuals=track_root_inv_residuals, - preconditioner_computation_config=instantiate_preconditioner_computation_config( + preconditioner_config=instantiate_preconditioner_config( preconditioner_computation_type ), ) # type: ignore[assignment] @@ -537,31 +538,35 @@ def instantiate_grafting_config( raise ValueError(f"Invalid GraftingType {grafting_type}!") -def instantiate_preconditioner_computation_config( +def instantiate_preconditioner_config( preconditioner_computation_type: PreconditionerComputationType, -) -> PreconditionerComputationConfig: +) -> PreconditionerConfig: if preconditioner_computation_type == PreconditionerComputationType.EIGEN_ROOT_INV: - return EigenConfig() + return DefaultShampooConfig elif ( preconditioner_computation_type == PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV ): - return CoupledNewtonConfig() + return ShampooPreconditionerConfig( + amortized_computation_config=CoupledNewtonConfig(), + ) elif ( preconditioner_computation_type == PreconditionerComputationType.COUPLED_HIGHER_ORDER_ROOT_INV ): - return CoupledHigherOrderConfig() + return ShampooPreconditionerConfig( + amortized_computation_config=CoupledHigherOrderConfig(), + ) elif ( preconditioner_computation_type == PreconditionerComputationType.EIGH_EIGENVALUE_CORRECTION ): - return EighEigenvalueCorrectionConfig() + return DefaultEigenvalueCorrectedShampooConfig elif ( preconditioner_computation_type == PreconditionerComputationType.QR_EIGENVALUE_CORRECTION ): - return QREigenvalueCorrectionConfig() + return DefaultSOAPConfig else: raise ValueError( f"Invalid PreconditionerComputationType {preconditioner_computation_type}!" 
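As a side note (a sketch, not part of the patch), the mapping above also shows how non-default amortized computations compose with the new wrapper configs instead of the `DefaultShampooConfig`/`DefaultSOAPConfig` shortcuts; importing `QRConfig` from `matrix_functions_types` is assumed, consistent with the rest of this patch.

```python
# Sketch only: composing non-default amortized computations with the new wrapper configs.
from distributed_shampoo import (
    CoupledHigherOrderConfig,
    EigenvalueCorrectedShampooPreconditionerConfig,
    ShampooPreconditionerConfig,
)
from matrix_functions_types import QRConfig

# Equivalent to the COUPLED_HIGHER_ORDER_ROOT_INV branch above.
higher_order_config = ShampooPreconditionerConfig(
    amortized_computation_config=CoupledHigherOrderConfig(),
)

# Like DefaultSOAPConfig, but running more QR iterations per amortized update.
soap_config = EigenvalueCorrectedShampooPreconditionerConfig(
    amortized_computation_config=QRConfig(max_iterations=3),
)
```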
diff --git a/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py index f63ab7a..19e9177 100644 --- a/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py +++ b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py @@ -17,13 +17,13 @@ import torch from distributed_shampoo.distributed_shampoo import DistributedShampoo +from distributed_shampoo.shampoo_types import ( + DefaultEigenvalueCorrectedShampooConfig, + DefaultSOAPConfig, +) from distributed_shampoo.tests.shampoo_test_utils import ( compare_two_optimizers_on_weight_and_loss, ) -from matrix_functions_types import ( - DefaultEighEigenvalueCorrectionConfig, - QREigenvalueCorrectionConfig, -) from torch.optim.adagrad import Adagrad from torch.optim.adam import Adam from torch.optim.adamw import AdamW @@ -49,12 +49,12 @@ def _optim_factory( def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. - for weight_decay, device, preconditioner_computation_config in product( + for weight_decay, device, preconditioner_config in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), - (DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()), + (DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -64,7 +64,7 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: with self.subTest( weight_decay=weight_decay, device=device, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ): compare_two_optimizers_on_weight_and_loss( control_optim_factory=partial( @@ -81,19 +81,19 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: start_preconditioning_step=math.inf, use_decoupled_weight_decay=False, grafting_config=None, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ), device=device, ) def test_adam_eigenvalue_correction_on_quadratic(self) -> None: # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. 
- for weight_decay, device, preconditioner_computation_config in product( + for weight_decay, device, preconditioner_config in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), - (DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()), + (DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -104,7 +104,7 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None: with self.subTest( weight_decay=weight_decay, device=device, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ): compare_two_optimizers_on_weight_and_loss( control_optim_factory=partial( @@ -122,19 +122,19 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None: start_preconditioning_step=math.inf, use_decoupled_weight_decay=False, grafting_config=None, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ), device=device, ) def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. - for weight_decay, device, preconditioner_computation_config in product( + for weight_decay, device, preconditioner_config in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), - (DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()), + (DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -145,7 +145,7 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: with self.subTest( weight_decay=weight_decay, device=device, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ): compare_two_optimizers_on_weight_and_loss( control_optim_factory=partial( @@ -163,19 +163,19 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: start_preconditioning_step=math.inf, use_decoupled_weight_decay=True, grafting_config=None, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ), device=device, ) def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. 
- for weight_decay, device, preconditioner_computation_config in product( + for weight_decay, device, preconditioner_config in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), - (DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()), + (DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -185,7 +185,7 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: with self.subTest( weight_decay=weight_decay, device=device, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ): compare_two_optimizers_on_weight_and_loss( control_optim_factory=partial( @@ -206,7 +206,7 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: use_decoupled_weight_decay=False, grafting_config=None, use_bias_correction=False, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ), device=device, ) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index e0dd6a4..306e9e5 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -8,11 +8,20 @@ """ import enum -from dataclasses import dataclass +from dataclasses import dataclass, field import torch from commons import AbstractDataclass + +from matrix_functions_types import ( + DefaultEigenConfig, + DefaultEighEigenvectorConfig, + EigenvectorConfig, + MatrixFunctionConfig, + QRConfig, + RootInvConfig, +) from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import ShardingStrategy from torch.nn.parameter import Parameter @@ -35,7 +44,7 @@ PRECISION_CONFIG = "precision_config" PRECONDITION_FREQUENCY = "precondition_frequency" PRECONDITIONER_DTYPE = "preconditioner_dtype" -PRECONDITIONER_COMPUTATION_CONFIG = "preconditioner_computation_config" +PRECONDITIONER_CONFIG = "preconditioner_config" START_PRECONDITIONING_STEP = "start_preconditioning_step" USE_EIGENVALUE_CORRECTION = "use_eigenvalue_correction" USE_BIAS_CORRECTION = "use_bias_correction" @@ -71,6 +80,58 @@ class PreconditionerValueError(ValueError): ###### DATACLASSES ###### +@dataclass(init=False) +class PreconditionerConfig(AbstractDataclass): + """Configuration for preconditioner computation in DistributedShampoo. + + Args: + amortized_computation_config (MatrixFunctionConfig): Configuration for the amortized computation, e.g., inverse-root or eigenvector computation. + + """ + + amortized_computation_config: MatrixFunctionConfig + + +@dataclass(kw_only=True) +class ShampooPreconditionerConfig(PreconditionerConfig): + """Configuration for Shampoo preconditioner computation. + + Args: + amortized_computation_config (RootInvConfig): Configuration for the inverse-root computation. (Default: DefaultEigenConfig) + + """ + + amortized_computation_config: RootInvConfig = field( + default_factory=lambda: DefaultEigenConfig + ) + + +DefaultShampooConfig = ShampooPreconditionerConfig() + + +@dataclass(kw_only=True) +class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig): + """Configuration for eigenvalue-corrected Shampoo/SOAP preconditioner computation. + + Args: + amortized_computation_config (EigenvectorConfig): Configuration for the eigenvector computation. 
+ (Default: DefaultEighEigenvectorConfig) + + """ + + amortized_computation_config: EigenvectorConfig = field( + default_factory=lambda: DefaultEighEigenvectorConfig + ) + + +DefaultEigenvalueCorrectedShampooConfig = ( + EigenvalueCorrectedShampooPreconditionerConfig() +) +DefaultSOAPConfig = EigenvalueCorrectedShampooPreconditionerConfig( + amortized_computation_config=QRConfig(), +) + + @dataclass class FSDPParameterMetadata: """FSDP Metadata for a parameter. diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index 396db52..92eb0aa 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -22,21 +22,21 @@ from distributed_shampoo.shampoo_types import ( AdaGradGraftingConfig, DDPShampooConfig, + DefaultEigenvalueCorrectedShampooConfig, + DefaultShampooConfig, DistributedConfig, GRAFTING_PRECONDITIONER_LIST, GraftingConfig, MASKED_FILTERED_GRAD_LIST, MASKED_MOMENTUM_LIST, PrecisionConfig, + PreconditionerConfig, SGDGraftingConfig, SHAMPOO_PRECONDITIONER_LIST, + ShampooPreconditionerConfig, ShampooPT2CompileConfig, ) from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList -from matrix_functions_types import ( - DefaultEigenConfig, - DefaultEighEigenvalueCorrectionConfig, -) from torch import nn @@ -46,17 +46,38 @@ def setUp(self) -> None: nn.Linear(5, 10, bias=False), ) - def test_invalid_grafting_config(self) -> None: + def test_invalid_preconditioner_config(self) -> None: with ( mock.patch.object( - distributed_shampoo, "type", side_effect=lambda object: GraftingConfig + distributed_shampoo, + "type", + side_effect=lambda object: { + ShampooPreconditionerConfig: PreconditionerConfig + }.get(type(object), type(object)), ), self.assertRaisesRegex( NotImplementedError, - re.escape( - "Unsupported grafting config: group[GRAFTING_CONFIG]=SGDGraftingConfig" + re.escape("group[PRECONDITIONER_CONFIG]=ShampooPreconditionerConfig"), + ), + ): + DistributedShampoo( + self._model.parameters(), + preconditioner_config=DefaultShampooConfig, + ) + + def test_invalid_grafting_config(self) -> None: + with ( + mock.patch.object( + distributed_shampoo, + "type", + side_effect=lambda object: {SGDGraftingConfig: GraftingConfig}.get( + type(object), type(object) ), ), + self.assertRaisesRegex( + NotImplementedError, + re.escape("group[GRAFTING_CONFIG]=SGDGraftingConfig"), + ), ): DistributedShampoo( self._model.parameters(), @@ -251,7 +272,7 @@ def test_setting_exponent_multiplier_with_eigen_config(self) -> None: lr=0.01, start_preconditioning_step=1, exponent_multiplier=2.0, - preconditioner_computation_config=DefaultEigenConfig, + preconditioner_config=DefaultShampooConfig, ) self.assertCountEqual( [r.msg for r in cm.records], @@ -264,7 +285,7 @@ def test_conflict_eigenvalue_correction_and_track_root_inv_residuals(self) -> No with self.assertRaisesRegex( ValueError, re.escape( - "track_root_inv_residuals=True has to be set to False when preconditioner_computation_config=EighEigenvalueCorrectionConfig(retry_double_precision=True) is not an instance of RootInvConfig." + "track_root_inv_residuals=True has to be set to False when amortized_computation_config=EighEigenvectorConfig(retry_double_precision=True) is not an instance of RootInvConfig." 
), ): DistributedShampoo( @@ -272,7 +293,7 @@ def test_conflict_eigenvalue_correction_and_track_root_inv_residuals(self) -> No lr=0.01, start_preconditioning_step=1, track_root_inv_residuals=True, - preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + preconditioner_config=DefaultEigenvalueCorrectedShampooConfig, ) @@ -495,7 +516,7 @@ def setUp(self) -> None: ), "use_merge_dims": True, "precision_config": PrecisionConfig(), - "preconditioner_computation_config": DefaultEigenConfig, + "preconditioner_config": DefaultShampooConfig, } }, } @@ -889,7 +910,7 @@ def _instantiate_optimizer( distributed_config=None, grafting_config=None, precision_config=precision_config, - preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + preconditioner_config=DefaultEigenvalueCorrectedShampooConfig, ) def _assert_state_list_dtype( diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 7bb3231..ed15da8 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -19,7 +19,11 @@ from typing import Any, cast, Generic, TypeVar import torch -from distributed_shampoo.shampoo_types import PrecisionConfig, PreconditionerValueError +from distributed_shampoo.shampoo_types import ( + PrecisionConfig, + PreconditionerConfig, + PreconditionerValueError, +) from distributed_shampoo.utils.shampoo_block_info import BlockInfo from distributed_shampoo.utils.shampoo_quantization import ( QuantizedTensor, @@ -38,12 +42,7 @@ matrix_inverse_root, ) -from matrix_functions_types import ( - DefaultEigenConfig, - EigenvalueCorrectionConfig, - PreconditionerComputationConfig, - RootInvConfig, -) +from matrix_functions_types import EigenvectorConfig, RootInvConfig from optimizer_modules import OptimizerModule from torch import Tensor from torch.autograd import profiler @@ -428,7 +427,7 @@ class BaseShampooPreconditionerList( distributor_selector (tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter is selected by the current Distributor. precision_config (PrecisionConfig): Data types for optimizer states. (Default: all fields torch.float) - preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation. (Default: DefaultEigenConfig) + preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation. (Default: DefaultShampooConfig) beta2 (float): Exponential moving average factor for Shampoo factor matrices. If beta2 = 1., will use unweighted sum. (Default: 1.0) epsilon (float): Epsilon term for regularizing preconditioner to ensure positive definiteness. (Default: 1e-12) @@ -452,7 +451,7 @@ def __init__( block_info_list: tuple[BlockInfo, ...], distributor_selector: tuple[bool, ...], precision_config: PrecisionConfig, - preconditioner_computation_config: PreconditionerComputationConfig = DefaultEigenConfig, + preconditioner_config: PreconditionerConfig, beta2: float = 1.0, epsilon: float = 1e-12, inv_root_override: int | tuple[int, ...] = 0, @@ -463,7 +462,7 @@ def __init__( # Initialize parameters. 
self._precision_config = precision_config - self._preconditioner_computation_config = preconditioner_computation_config + self._preconditioner_config = preconditioner_config self._beta2 = beta2 self._epsilon = epsilon self._inv_root_override = inv_root_override @@ -959,20 +958,22 @@ def _amortized_computation(self) -> None: ) # Compute inverse preconditioner. + root_inv_config = cast( + RootInvConfig, + self._preconditioner_config.amortized_computation_config, + ) try: computed_inv_factor_matrix = matrix_inverse_root( A=bias_corrected_factor_matrix, root=Fraction( root / getattr( - self._preconditioner_computation_config, + root_inv_config, "exponent_multiplier", 1, ) ), - root_inv_config=cast( - RootInvConfig, self._preconditioner_computation_config - ), + root_inv_config=root_inv_config, epsilon=self._epsilon, is_diagonal=bool(is_factor_matrix_diagonal), ).to(dtype=inv_factor_matrix.dtype) @@ -983,7 +984,7 @@ def _amortized_computation(self) -> None: else: logger.warning( f"Matrix computation failed for factor matrix {factor_matrix_index} " - f"with {exception=}. Using previous inversed factor matrix and continuing..." + f"with {exception=}. Using previous inverted factor matrix and continuing..." ) # Define computed_inv_factor_matrix to prevent undefined local variable error. computed_inv_factor_matrix = inv_factor_matrix @@ -1020,6 +1021,10 @@ def quantize_preconditioners(self) -> None: def compute_root_inverse_residuals( self, ) -> tuple[tuple[Tensor, ...], tuple[Tensor, ...]]: + root_inv_config = cast( + RootInvConfig, + self._preconditioner_config.amortized_computation_config, + ) relative_errors = [] relative_residuals = [] @@ -1043,15 +1048,13 @@ def compute_root_inverse_residuals( root=Fraction( root / getattr( - self._preconditioner_computation_config, + root_inv_config, "exponent_multiplier", 1, ) ), epsilon=self._epsilon, - root_inv_config=cast( - RootInvConfig, self._preconditioner_computation_config - ), + root_inv_config=root_inv_config, ) relative_errors.append(relative_error) relative_residuals.append(relative_residual) @@ -1277,14 +1280,15 @@ def _amortized_computation(self) -> None: ) # Compute eigenvectors of factor matrix. 
+ eigenvector_computation_config = cast( + EigenvectorConfig, + self._preconditioner_config.amortized_computation_config, + ) try: computed_eigenvectors = matrix_eigenvectors( A=factor_matrix, eigenvectors_estimate=factor_matrix_eigenvectors, - eigenvector_computation_config=cast( - EigenvalueCorrectionConfig, - self._preconditioner_computation_config, - ), + eigenvector_computation_config=eigenvector_computation_config, is_diagonal=bool(is_factor_matrix_diagonal), ) except Exception as exception: diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 6d38af6..8bb061f 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -15,7 +15,13 @@ from unittest import mock import torch -from distributed_shampoo.shampoo_types import PrecisionConfig, PreconditionerValueError +from distributed_shampoo.shampoo_types import ( + DefaultEigenvalueCorrectedShampooConfig, + DefaultShampooConfig, + PrecisionConfig, + PreconditionerValueError, + ShampooPreconditionerConfig, +) from distributed_shampoo.utils import shampoo_preconditioner_list from distributed_shampoo.utils.shampoo_block_info import BlockInfo @@ -29,7 +35,7 @@ ShampooPreconditionerList, ) from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList -from matrix_functions_types import DefaultEighEigenvalueCorrectionConfig, EigenConfig +from matrix_functions import EigenConfig from torch import Tensor @@ -304,6 +310,7 @@ def test_abstract_methods(self) -> None: ), distributor_selector=(True,), precision_config=PrecisionConfig(), + preconditioner_config=DefaultShampooConfig, beta2=1.0, ) @@ -526,6 +533,7 @@ def _instantiate_preconditioner_list( # type: ignore[override] "inv_root_override": 0, "use_bias_correction": True, "use_protected_eigh": True, + "preconditioner_config": DefaultShampooConfig, } | kwargs return ShampooPreconditionerList( block_list=self._block_list, @@ -533,7 +541,7 @@ def _instantiate_preconditioner_list( # type: ignore[override] block_info_list=self._block_info_list, distributor_selector=self._distributor_selector, precision_config=PrecisionConfig(factor_matrix_dtype=torch.float64), - **kwargs, + **kwargs, # type: ignore[arg-type] ) def test_update_preconditioners_and_precondition(self) -> None: @@ -718,7 +726,9 @@ def test_inverse_roots_from_override( """ Tests that the inverse roots are computed correctly from inv_root_override. 
""" - preconditioner_computation_config = EigenConfig(exponent_multiplier=2.0) + preconditioner_config = ShampooPreconditionerConfig( + amortized_computation_config=EigenConfig(exponent_multiplier=2.0), + ) masked_grad_list1 = ( torch.tensor([1.0, 0.0]), @@ -743,7 +753,7 @@ def test_inverse_roots_from_override( beta2=1.0, use_bias_correction=True, inv_root_override=inv_root_override, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ), masked_grad_lists=[masked_grad_list1, masked_grad_list2], masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, @@ -766,6 +776,7 @@ def test_compute_root_inverse_residuals(self) -> None: block_info_list=(self._block_info_list[0],), distributor_selector=(self._distributor_selector[0],), precision_config=PrecisionConfig(), + preconditioner_config=DefaultShampooConfig, epsilon=0.0, ) @@ -811,7 +822,7 @@ def _instantiate_preconditioner_list( # type: ignore[override] "inv_root_override": 0, "use_bias_correction": True, "use_protected_eigh": True, - "preconditioner_computation_config": DefaultEighEigenvalueCorrectionConfig, + "preconditioner_config": DefaultEigenvalueCorrectedShampooConfig, } | kwargs return EigenvalueCorrectedShampooPreconditionerList( block_list=self._block_list, @@ -1044,7 +1055,7 @@ def test_inverse_roots_from_override( beta2=1.0, use_bias_correction=True, inv_root_override=inv_root_override, - preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + preconditioner_config=DefaultEigenvalueCorrectedShampooConfig, ), masked_grad_lists=[masked_grad_list1, masked_grad_list2], masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, diff --git a/matrix_functions.py b/matrix_functions.py index 70b1366..b3c2ed1 100644 --- a/matrix_functions.py +++ b/matrix_functions.py @@ -20,11 +20,11 @@ CoupledHigherOrderConfig, CoupledNewtonConfig, DefaultEigenConfig, - DefaultEighEigenvalueCorrectionConfig, + DefaultEighEigenvectorConfig, EigenConfig, - EigenvalueCorrectionConfig, - EighEigenvalueCorrectionConfig, - QREigenvalueCorrectionConfig, + EigenvectorConfig, + EighEigenvectorConfig, + QRConfig, RootInvConfig, ) @@ -166,7 +166,7 @@ def _matrix_inverse_root_diagonal( return torch.diag((torch.diagonal(A) + epsilon).pow(torch.as_tensor(-1.0 / root))) -def _compute_eigenvalue_decomposition( +def matrix_eigenvalue_decomposition( A: Tensor, retry_double_precision: bool = True, ) -> tuple[Tensor, Tensor]: @@ -231,7 +231,7 @@ def _matrix_inverse_root_eigen( raise ValueError(f"Root {root} should be positive!") # compute eigendecomposition and compute minimum eigenvalue - L, Q = _compute_eigenvalue_decomposition( + L, Q = matrix_eigenvalue_decomposition( A, retry_double_precision=retry_double_precision ) @@ -599,7 +599,7 @@ def compute_matrix_root_inverse_residuals( def matrix_eigenvectors( A: Tensor, eigenvectors_estimate: Tensor | None = None, - eigenvector_computation_config: EigenvalueCorrectionConfig = DefaultEighEigenvalueCorrectionConfig, + eigenvector_computation_config: EigenvectorConfig = DefaultEighEigenvectorConfig, is_diagonal: bool = False, ) -> Tensor: """Compute eigenvectors of matrix using eigendecomposition of symmetric positive (semi-)definite matrix. @@ -612,8 +612,8 @@ def matrix_eigenvectors( A (Tensor): Square matrix of interest. eigenvectors_estimate (Tensor | None): The current estimate of the eigenvectors of A. 
(Default: None) - eigenvector_computation_config (EigenvalueCorrectionConfig): Determines how eigenvectors are computed. - (Default: DefaultEighEigenvalueCorrectionConfig) + eigenvector_computation_config (EigenvectorConfig): Determines how eigenvectors are computed. + (Default: DefaultEighEigenvectorConfig) is_diagonal (bool): Whether A is diagonal. (Default: False) Returns: @@ -638,15 +638,15 @@ def matrix_eigenvectors( device=A.device, ) - if type(eigenvector_computation_config) is EighEigenvalueCorrectionConfig: - return _compute_eigenvalue_decomposition( + if type(eigenvector_computation_config) is EighEigenvectorConfig: + return matrix_eigenvalue_decomposition( A, retry_double_precision=eigenvector_computation_config.retry_double_precision, )[1] - elif type(eigenvector_computation_config) is QREigenvalueCorrectionConfig: + elif type(eigenvector_computation_config) is QRConfig: assert ( eigenvectors_estimate is not None - ), "Estimate of eigenvectors is required when using QREigenvalueCorrectionConfig." + ), "Estimate of eigenvectors is required when using QRConfig." return _compute_orthogonal_iterations( A, eigenvectors_estimate=eigenvectors_estimate, @@ -685,7 +685,7 @@ def _compute_orthogonal_iterations( """ if not eigenvectors_estimate.any(): - return _compute_eigenvalue_decomposition(A)[1] + return matrix_eigenvalue_decomposition(A)[1] # Perform orthogonal/simultaneous iterations (QR algorithm). Q = eigenvectors_estimate diff --git a/matrix_functions_types.py b/matrix_functions_types.py index af267d3..2fed2a3 100644 --- a/matrix_functions_types.py +++ b/matrix_functions_types.py @@ -13,30 +13,42 @@ @dataclass(init=False) -class PreconditionerComputationConfig(AbstractDataclass): - """Configuration for preconditioner computation in Shampoo.""" +class MatrixFunctionConfig(AbstractDataclass): + """Base dataclass for matrix function configurations.""" + + +@dataclass(kw_only=True) +class EigenvalueDecompositionConfig(MatrixFunctionConfig): + """Configuration for eigenvalue decomposition. + + Args: + retry_double_precision (bool): Whether to retry eigendecomposition with higher (double) precision if lower precision fails due + to CuSOLVER failure. (Default: True) + + """ + + retry_double_precision: bool = True @dataclass(init=False) -class RootInvConfig(PreconditionerComputationConfig): - """Base dataclass for matrix root inverse method configurations in Shampoo.""" +class RootInvConfig(MatrixFunctionConfig): + """Base dataclass for matrix root inverse method configurations.""" @dataclass(kw_only=True) -class EigenConfig(RootInvConfig): - """Configuration for eigendecomposition method in Shampoo. +class EigenConfig(RootInvConfig, EigenvalueDecompositionConfig): + """Configuration for matrix root inverse via an eigendecomposition. Args: - make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive semi-definite. (Default: True) - retry_double_precision (bool): Whether to re-trying eigendecomposition with higher(double) precision if lower precision fails due + retry_double_precision (bool): Whether to retry eigendecomposition with higher (double) precision if lower precision fails due to CuSOLVER failure. (Default: True) + make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure the matrix is numerically positive semi-definite. (Default: True) exponent_multiplier (float): Number to be multiplied to the numerator of the inverse root, i.e., eta where the exponent is -eta / (2 * p).
+ (Default: 1.0) """ make_positive_semidefinite: bool = True - retry_double_precision: bool = True exponent_multiplier: float = 1.0 @@ -45,7 +57,7 @@ class EigenConfig(RootInvConfig): @dataclass(kw_only=True) class CoupledNewtonConfig(RootInvConfig): - """Configuration for coupled Newton method in Shampoo. + """Configuration for matrix root inverse via coupled Newton method. Args: max_iterations (int): Maximum number of iterations for coupled Newton iteration. (Default: 100) @@ -59,7 +71,7 @@ class CoupledNewtonConfig(RootInvConfig): @dataclass(kw_only=True) class CoupledHigherOrderConfig(RootInvConfig): - """Configuration for coupled higher-order method in Shampoo. + """Configuration for matrix root inverse via coupled higher-order method. Args: rel_epsilon (float): Relative epsilon for coupled higher order method. Adds epsilon * lambda_max * I to matrix @@ -81,29 +93,27 @@ class CoupledHigherOrderConfig(RootInvConfig): @dataclass(init=False) -class EigenvalueCorrectionConfig(PreconditionerComputationConfig): - """Base dataclass for matrix eigenvector method configurations in eigenvalue-corrected Shampoo.""" +class EigenvectorConfig(MatrixFunctionConfig): + """Base dataclass for matrix eigenvector method configurations.""" @dataclass(kw_only=True) -class EighEigenvalueCorrectionConfig(EigenvalueCorrectionConfig): - """Configuration for eigendecomposition method used in eigenvalue-corrected Shampoo. +class EighEigenvectorConfig(EigenvectorConfig, EigenvalueDecompositionConfig): + """Configuration for eigenvectors via an eigendecomposition. Args: - retry_double_precision (bool): Whether to re-trying eigendecomposition with higher(double) precision if lower precision fails due + retry_double_precision (bool): Whether to retry eigendecomposition with higher (double) precision if lower precision fails due to CuSOLVER failure. (Default: True) """ - retry_double_precision: bool = True -DefaultEighEigenvalueCorrectionConfig = EighEigenvalueCorrectionConfig() +DefaultEighEigenvectorConfig = EighEigenvectorConfig() @dataclass(kw_only=True) -class QREigenvalueCorrectionConfig(EigenvalueCorrectionConfig): - """Configuration for orthogonal/simultaneous iterations (QR algorithm) used in eigenvalue-corrected Shampoo. +class QRConfig(EigenvectorConfig): + """Configuration for eigenvectors via orthogonal/simultaneous iterations (QR algorithm). Args: max_iterations (int): The maximum number of iterations to perform. (Default: 1) diff --git a/tests/matrix_functions_test.py b/tests/matrix_functions_test.py index 1658201..3d31a7b 100644 --- a/tests/matrix_functions_test.py +++ b/tests/matrix_functions_test.py @@ -35,8 +35,8 @@ CoupledHigherOrderConfig, CoupledNewtonConfig, EigenConfig, - EigenvalueCorrectionConfig, - QREigenvalueCorrectionConfig, + EigenvectorConfig, + QRConfig, RootInvConfig, ) from torch import Tensor @@ -859,18 +859,16 @@ def test_matrix_eigenvectors(self) -> None: rtol=rtol, ) - # Tests for `QREigenvalueCorrectionConfig`. + # Tests for `QRConfig`. initialization_strategies = { "zero": lambda A: torch.zeros_like(A), "identity": lambda A: torch.eye(A.shape[0], dtype=A.dtype, device=A.device), "exact": lambda A: matrix_eigenvectors(A), # Eigendecomposition. } for name, initialization_fn in initialization_strategies.items(): - with self.subTest( - f"Test with QREigenvalueCorrectionConfig with {name} initialization." - ): + with self.subTest(f"Test with QRConfig with {name} initialization."): # Set `max_iterations` to large int to run until numerical tolerance.
- qr_config = QREigenvalueCorrectionConfig(max_iterations=10_000) + qr_config = QRConfig(max_iterations=10_000) for A, expected_eigenvectors in zip( A_list, expected_eigenvectors_list, strict=True ): @@ -899,12 +897,12 @@ def test_invalid_eigenvalue_correction_config( mock.patch.object( matrix_functions, "type", - side_effect=lambda object: EigenvalueCorrectionConfig, + side_effect=lambda object: EigenvectorConfig, ), self.assertRaisesRegex( NotImplementedError, re.escape( - "Eigenvector computation method is not implemented! Specified eigenvector method is eigenvector_computation_config=EighEigenvalueCorrectionConfig(retry_double_precision=True)." + "Eigenvector computation method is not implemented! Specified eigenvector method is eigenvector_computation_config=EighEigenvectorConfig(retry_double_precision=True)." ), ), ):
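Finally, a hedged sketch (not part of the patch) of the renamed `matrix_functions` entry points exercised by the tests above; the small symmetric matrix is an assumed example input.

```python
# Sketch only: exercises the renamed eigenvector configs with matrix_eigenvectors.
import torch
from matrix_functions import matrix_eigenvectors
from matrix_functions_types import DefaultEighEigenvectorConfig, QRConfig

A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])

# Eigh path: full eigendecomposition (this is also the default config).
Q = matrix_eigenvectors(A, eigenvector_computation_config=DefaultEighEigenvectorConfig)

# QR path: orthogonal iterations refine a previous estimate, which is required here.
Q_refined = matrix_eigenvectors(
    A,
    eigenvectors_estimate=Q,
    eigenvector_computation_config=QRConfig(max_iterations=5),
)
```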