diff --git a/distributed_shampoo/__init__.py b/distributed_shampoo/__init__.py index da71b9e..97a898a 100644 --- a/distributed_shampoo/__init__.py +++ b/distributed_shampoo/__init__.py @@ -64,7 +64,7 @@ "ShampooPreconditionerConfig", # Based on `PreconditionerConfig`. "DefaultShampooConfig", # Default `ShampooPreconditionerConfig` using `EigenConfig`. "EigenvalueCorrectedShampooPreconditionerConfig", # Based on `PreconditionerConfig`. - "DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighConfig`. + "DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigenvectorConfig`. "DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QRConfig`. # matrix functions configs. "MatrixFunctionConfig", # Abstract base class. @@ -74,8 +74,8 @@ "CoupledNewtonConfig", # Based on `RootInvConfig`. "CoupledHigherOrderConfig", # Based on `RootInvConfig`. "EigenvectorConfig", # Abstract base class (based on `MatrixFunctionConfig`). - "EighConfig", # Based on `EigenvectorConfig`. - "DefaultEighConfig", # Default `EigenvectorConfig` using `EighConfig`. + "EighEigenvectorConfig", # Based on `EigenvectorConfig`. + "DefaultEighEigenvectorConfig", # Default `EigenvectorConfig` using `EighEigenvectorConfig`. "QRConfig", # Based on `EigenvectorConfig`. # Other utilities. "compile_fsdp_parameter_metadata", # For `FSDPShampooConfig` and `HSDPShampooConfig`. diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index bf5ccdd..306e9e5 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -16,7 +16,7 @@ from matrix_functions_types import ( DefaultEigenConfig, - DefaultEighConfig, + DefaultEighEigenvectorConfig, EigenvectorConfig, MatrixFunctionConfig, QRConfig, @@ -114,12 +114,13 @@ class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig): """Configuration for eigenvalue-corrected Shampoo/SOAP preconditioner computation. Args: - amortized_computation_config (EigenvectorConfig): Configuration for the eigenvector computation. (Default: DefaultEighConfig) + amortized_computation_config (EigenvectorConfig): Configuration for the eigenvector computation. + (Default: DefaultEighEigenvectorConfig) """ amortized_computation_config: EigenvectorConfig = field( - default_factory=lambda: DefaultEighConfig + default_factory=lambda: DefaultEighEigenvectorConfig ) diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index 01d841f..92eb0aa 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -285,7 +285,7 @@ def test_conflict_eigenvalue_correction_and_track_root_inv_residuals(self) -> No with self.assertRaisesRegex( ValueError, re.escape( - "track_root_inv_residuals=True has to be set to False when amortized_computation_config=EighConfig(retry_double_precision=True) is not an instance of RootInvConfig." + "track_root_inv_residuals=True has to be set to False when amortized_computation_config=EighEigenvectorConfig(retry_double_precision=True) is not an instance of RootInvConfig." ), ): DistributedShampoo( diff --git a/matrix_functions.py b/matrix_functions.py index 9a14093..b3c2ed1 100644 --- a/matrix_functions.py +++ b/matrix_functions.py @@ -20,10 +20,10 @@ CoupledHigherOrderConfig, CoupledNewtonConfig, DefaultEigenConfig, - DefaultEighConfig, + DefaultEighEigenvectorConfig, EigenConfig, EigenvectorConfig, - EighConfig, + EighEigenvectorConfig, QRConfig, RootInvConfig, ) @@ -166,7 +166,7 @@ def _matrix_inverse_root_diagonal( return torch.diag((torch.diagonal(A) + epsilon).pow(torch.as_tensor(-1.0 / root))) -def _compute_eigenvalue_decomposition( +def matrix_eigenvalue_decomposition( A: Tensor, retry_double_precision: bool = True, ) -> tuple[Tensor, Tensor]: @@ -231,7 +231,7 @@ def _matrix_inverse_root_eigen( raise ValueError(f"Root {root} should be positive!") # compute eigendecomposition and compute minimum eigenvalue - L, Q = _compute_eigenvalue_decomposition( + L, Q = matrix_eigenvalue_decomposition( A, retry_double_precision=retry_double_precision ) @@ -599,7 +599,7 @@ def compute_matrix_root_inverse_residuals( def matrix_eigenvectors( A: Tensor, eigenvectors_estimate: Tensor | None = None, - eigenvector_computation_config: EigenvectorConfig = DefaultEighConfig, + eigenvector_computation_config: EigenvectorConfig = DefaultEighEigenvectorConfig, is_diagonal: bool = False, ) -> Tensor: """Compute eigenvectors of matrix using eigendecomposition of symmetric positive (semi-)definite matrix. @@ -613,7 +613,7 @@ def matrix_eigenvectors( eigenvectors_estimate (Tensor | None): The current estimate of the eigenvectors of A. (Default: None) eigenvector_computation_config (EigenvectorConfig): Determines how eigenvectors are computed. - (Default: DefaultEighConfig) + (Default: DefaultEighEigenvectorConfig) is_diagonal (bool): Whether A is diagonal. (Default: False) Returns: @@ -638,8 +638,8 @@ def matrix_eigenvectors( device=A.device, ) - if type(eigenvector_computation_config) is EighConfig: - return _compute_eigenvalue_decomposition( + if type(eigenvector_computation_config) is EighEigenvectorConfig: + return matrix_eigenvalue_decomposition( A, retry_double_precision=eigenvector_computation_config.retry_double_precision, )[1] @@ -685,7 +685,7 @@ def _compute_orthogonal_iterations( """ if not eigenvectors_estimate.any(): - return _compute_eigenvalue_decomposition(A)[1] + return matrix_eigenvalue_decomposition(A)[1] # Perform orthogonal/simultaneous iterations (QR algorithm). Q = eigenvectors_estimate diff --git a/matrix_functions_types.py b/matrix_functions_types.py index 4a58450..2fed2a3 100644 --- a/matrix_functions_types.py +++ b/matrix_functions_types.py @@ -17,26 +17,38 @@ class MatrixFunctionConfig(AbstractDataclass): """Base dataclass for matrix function configurations.""" +@dataclass(kw_only=True) +class EigenvalueDecompositionConfig(MatrixFunctionConfig): + """Configuration for eigenvalue decomposition. + + Args: + retry_double_precision (bool): Whether to re-trying eigendecomposition with higher (double) precision if lower precision fails due + to CuSOLVER failure. (Default: True) + + """ + + retry_double_precision: bool = True + + @dataclass(init=False) class RootInvConfig(MatrixFunctionConfig): - """Base dataclass for matrix root inverse (`matrix_inverse_root`) method configurations.""" + """Base dataclass for matrix root inverse method configurations.""" @dataclass(kw_only=True) -class EigenConfig(RootInvConfig): - """Configuration for eigendecomposition (`_matrix_inverse_root_eigen`) method. +class EigenConfig(RootInvConfig, EigenvalueDecompositionConfig): + """Configuration for matrix root inverse via an eigendecomposition. Args: - make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive semi-definite. (Default: True) - retry_double_precision (bool): Whether to re-trying eigendecomposition with higher(double) precision if lower precision fails due + retry_double_precision (bool): Whether to re-trying eigendecomposition with higher (double) precision if lower precision fails due to CuSOLVER failure. (Default: True) + make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive semi-definite. (Default: True) exponent_multiplier (float): Number to be multiplied to the numerator of the inverse root, i.e., eta where the exponent is -eta / (2 * p). (Default: 1.0) """ make_positive_semidefinite: bool = True - retry_double_precision: bool = True exponent_multiplier: float = 1.0 @@ -45,7 +57,7 @@ class EigenConfig(RootInvConfig): @dataclass(kw_only=True) class CoupledNewtonConfig(RootInvConfig): - """Configuration for coupled Newton (`_matrix_inverse_root_newton`) method. + """Configuration for matrix root inverse via coupled Newton method. Args: max_iterations (int): Maximum number of iterations for coupled Newton iteration. (Default: 100) @@ -59,7 +71,7 @@ class CoupledNewtonConfig(RootInvConfig): @dataclass(kw_only=True) class CoupledHigherOrderConfig(RootInvConfig): - """Configuration for coupled higher-order (`_matrix_inverse_root_higher_order`) method. + """Configuration for matrix root inverse via coupled higher-order method. Args: rel_epsilon (float): Relative epsilon for coupled higher order method. Adds epsilon * lambda_max * I to matrix @@ -82,28 +94,26 @@ class CoupledHigherOrderConfig(RootInvConfig): @dataclass(init=False) class EigenvectorConfig(MatrixFunctionConfig): - """Base dataclass for matrix eigenvector (`matrix_eigenvectors`) method.""" + """Base dataclass for matrix eigenvector method configurations.""" @dataclass(kw_only=True) -class EighConfig(EigenvectorConfig): - """Configuration for eigendecomposition (`_compute_eigenvalue_decomposition`) method. +class EighEigenvectorConfig(EigenvectorConfig, EigenvalueDecompositionConfig): + """Configuration for eigenvectors via an eigendecomposition. Args: - retry_double_precision (bool): Whether to re-trying eigendecomposition with higher(double) precision if lower precision fails due + retry_double_precision (bool): Whether to re-trying eigendecomposition with higher (double) precision if lower precision fails due to CuSOLVER failure. (Default: True) """ - retry_double_precision: bool = True - -DefaultEighConfig = EighConfig() +DefaultEighEigenvectorConfig = EighEigenvectorConfig() @dataclass(kw_only=True) class QRConfig(EigenvectorConfig): - """Configuration for orthogonal/simultaneous iterations/QR algorithm (`_compute_orthogonal_iterations`). + """Configuration for eigenvectors via orthogonal/simultaneous iterations/QR algorithm. Args: max_iterations (int): The maximum number of iterations to perform. (Default: 1) diff --git a/tests/matrix_functions_test.py b/tests/matrix_functions_test.py index 61f7069..3d31a7b 100644 --- a/tests/matrix_functions_test.py +++ b/tests/matrix_functions_test.py @@ -902,7 +902,7 @@ def test_invalid_eigenvalue_correction_config( self.assertRaisesRegex( NotImplementedError, re.escape( - "Eigenvector computation method is not implemented! Specified eigenvector method is eigenvector_computation_config=EighConfig(retry_double_precision=True)." + "Eigenvector computation method is not implemented! Specified eigenvector method is eigenvector_computation_config=EighEigenvectorConfig(retry_double_precision=True)." ), ), ):