Skip to content

Commit

Permalink
Merge pull request #62 from facebookresearch/rename-eighconfig
Browse files Browse the repository at this point in the history
Rename `EighConfig` to `EighEigenvectorConfig` and add `EigenvalueDecompositionConfig`
  • Loading branch information
runame authored Dec 17, 2024
2 parents 273e0a1 + 36e9131 commit 839a40a
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 33 deletions.
6 changes: 3 additions & 3 deletions distributed_shampoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
"ShampooPreconditionerConfig", # Based on `PreconditionerConfig`.
"DefaultShampooConfig", # Default `ShampooPreconditionerConfig` using `EigenConfig`.
"EigenvalueCorrectedShampooPreconditionerConfig", # Based on `PreconditionerConfig`.
"DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighConfig`.
"DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigenvectorConfig`.
"DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QRConfig`.
# matrix functions configs.
"MatrixFunctionConfig", # Abstract base class.
Expand All @@ -74,8 +74,8 @@
"CoupledNewtonConfig", # Based on `RootInvConfig`.
"CoupledHigherOrderConfig", # Based on `RootInvConfig`.
"EigenvectorConfig", # Abstract base class (based on `MatrixFunctionConfig`).
"EighConfig", # Based on `EigenvectorConfig`.
"DefaultEighConfig", # Default `EigenvectorConfig` using `EighConfig`.
"EighEigenvectorConfig", # Based on `EigenvectorConfig`.
"DefaultEighEigenvectorConfig", # Default `EigenvectorConfig` using `EighEigenvectorConfig`.
"QRConfig", # Based on `EigenvectorConfig`.
# Other utilities.
"compile_fsdp_parameter_metadata", # For `FSDPShampooConfig` and `HSDPShampooConfig`.
Expand Down
7 changes: 4 additions & 3 deletions distributed_shampoo/shampoo_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from matrix_functions_types import (
DefaultEigenConfig,
DefaultEighConfig,
DefaultEighEigenvectorConfig,
EigenvectorConfig,
MatrixFunctionConfig,
QRConfig,
Expand Down Expand Up @@ -114,12 +114,13 @@ class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig):
"""Configuration for eigenvalue-corrected Shampoo/SOAP preconditioner computation.
Args:
amortized_computation_config (EigenvectorConfig): Configuration for the eigenvector computation. (Default: DefaultEighConfig)
amortized_computation_config (EigenvectorConfig): Configuration for the eigenvector computation.
(Default: DefaultEighEigenvectorConfig)
"""

amortized_computation_config: EigenvectorConfig = field(
default_factory=lambda: DefaultEighConfig
default_factory=lambda: DefaultEighEigenvectorConfig
)


Expand Down
2 changes: 1 addition & 1 deletion distributed_shampoo/tests/distributed_shampoo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def test_conflict_eigenvalue_correction_and_track_root_inv_residuals(self) -> No
with self.assertRaisesRegex(
ValueError,
re.escape(
"track_root_inv_residuals=True has to be set to False when amortized_computation_config=EighConfig(retry_double_precision=True) is not an instance of RootInvConfig."
"track_root_inv_residuals=True has to be set to False when amortized_computation_config=EighEigenvectorConfig(retry_double_precision=True) is not an instance of RootInvConfig."
),
):
DistributedShampoo(
Expand Down
18 changes: 9 additions & 9 deletions matrix_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
CoupledHigherOrderConfig,
CoupledNewtonConfig,
DefaultEigenConfig,
DefaultEighConfig,
DefaultEighEigenvectorConfig,
EigenConfig,
EigenvectorConfig,
EighConfig,
EighEigenvectorConfig,
QRConfig,
RootInvConfig,
)
Expand Down Expand Up @@ -166,7 +166,7 @@ def _matrix_inverse_root_diagonal(
return torch.diag((torch.diagonal(A) + epsilon).pow(torch.as_tensor(-1.0 / root)))


def _compute_eigenvalue_decomposition(
def matrix_eigenvalue_decomposition(
A: Tensor,
retry_double_precision: bool = True,
) -> tuple[Tensor, Tensor]:
Expand Down Expand Up @@ -231,7 +231,7 @@ def _matrix_inverse_root_eigen(
raise ValueError(f"Root {root} should be positive!")

# compute eigendecomposition and compute minimum eigenvalue
L, Q = _compute_eigenvalue_decomposition(
L, Q = matrix_eigenvalue_decomposition(
A, retry_double_precision=retry_double_precision
)

Expand Down Expand Up @@ -599,7 +599,7 @@ def compute_matrix_root_inverse_residuals(
def matrix_eigenvectors(
A: Tensor,
eigenvectors_estimate: Tensor | None = None,
eigenvector_computation_config: EigenvectorConfig = DefaultEighConfig,
eigenvector_computation_config: EigenvectorConfig = DefaultEighEigenvectorConfig,
is_diagonal: bool = False,
) -> Tensor:
"""Compute eigenvectors of matrix using eigendecomposition of symmetric positive (semi-)definite matrix.
Expand All @@ -613,7 +613,7 @@ def matrix_eigenvectors(
eigenvectors_estimate (Tensor | None): The current estimate of the eigenvectors of A.
(Default: None)
eigenvector_computation_config (EigenvectorConfig): Determines how eigenvectors are computed.
(Default: DefaultEighConfig)
(Default: DefaultEighEigenvectorConfig)
is_diagonal (bool): Whether A is diagonal. (Default: False)
Returns:
Expand All @@ -638,8 +638,8 @@ def matrix_eigenvectors(
device=A.device,
)

if type(eigenvector_computation_config) is EighConfig:
return _compute_eigenvalue_decomposition(
if type(eigenvector_computation_config) is EighEigenvectorConfig:
return matrix_eigenvalue_decomposition(
A,
retry_double_precision=eigenvector_computation_config.retry_double_precision,
)[1]
Expand Down Expand Up @@ -685,7 +685,7 @@ def _compute_orthogonal_iterations(
"""
if not eigenvectors_estimate.any():
return _compute_eigenvalue_decomposition(A)[1]
return matrix_eigenvalue_decomposition(A)[1]

# Perform orthogonal/simultaneous iterations (QR algorithm).
Q = eigenvectors_estimate
Expand Down
42 changes: 26 additions & 16 deletions matrix_functions_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,38 @@ class MatrixFunctionConfig(AbstractDataclass):
"""Base dataclass for matrix function configurations."""


@dataclass(kw_only=True)
class EigenvalueDecompositionConfig(MatrixFunctionConfig):
"""Configuration for eigenvalue decomposition.
Holds the options common to every method that relies on an eigendecomposition; in this file both
`EigenConfig` and `EighEigenvectorConfig` inherit from it. Fields are keyword-only (`kw_only=True`).
Args:
retry_double_precision (bool): Whether to retry the eigendecomposition in higher (double) precision if the lower-precision
attempt fails due to a CuSOLVER failure. (Default: True)
"""

retry_double_precision: bool = True


@dataclass(init=False)
class RootInvConfig(MatrixFunctionConfig):
"""Base dataclass for matrix root inverse (`matrix_inverse_root`) method configurations."""
"""Base dataclass for matrix root inverse method configurations."""


@dataclass(kw_only=True)
class EigenConfig(RootInvConfig):
"""Configuration for eigendecomposition (`_matrix_inverse_root_eigen`) method.
class EigenConfig(RootInvConfig, EigenvalueDecompositionConfig):
"""Configuration for matrix root inverse via an eigendecomposition.
Args:
make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive semi-definite. (Default: True)
retry_double_precision (bool): Whether to re-trying eigendecomposition with higher(double) precision if lower precision fails due
retry_double_precision (bool): Whether to re-trying eigendecomposition with higher (double) precision if lower precision fails due
to CuSOLVER failure. (Default: True)
make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive semi-definite. (Default: True)
exponent_multiplier (float): Number to be multiplied to the numerator of the inverse root, i.e., eta where the
exponent is -eta / (2 * p). (Default: 1.0)
"""

make_positive_semidefinite: bool = True
retry_double_precision: bool = True
exponent_multiplier: float = 1.0


Expand All @@ -45,7 +57,7 @@ class EigenConfig(RootInvConfig):

@dataclass(kw_only=True)
class CoupledNewtonConfig(RootInvConfig):
"""Configuration for coupled Newton (`_matrix_inverse_root_newton`) method.
"""Configuration for matrix root inverse via coupled Newton method.
Args:
max_iterations (int): Maximum number of iterations for coupled Newton iteration. (Default: 100)
Expand All @@ -59,7 +71,7 @@ class CoupledNewtonConfig(RootInvConfig):

@dataclass(kw_only=True)
class CoupledHigherOrderConfig(RootInvConfig):
"""Configuration for coupled higher-order (`_matrix_inverse_root_higher_order`) method.
"""Configuration for matrix root inverse via coupled higher-order method.
Args:
rel_epsilon (float): Relative epsilon for coupled higher order method. Adds epsilon * lambda_max * I to matrix
Expand All @@ -82,28 +94,26 @@ class CoupledHigherOrderConfig(RootInvConfig):

@dataclass(init=False)
class EigenvectorConfig(MatrixFunctionConfig):
"""Base dataclass for matrix eigenvector (`matrix_eigenvectors`) method."""
"""Base dataclass for matrix eigenvector method configurations."""


@dataclass(kw_only=True)
class EighConfig(EigenvectorConfig):
"""Configuration for eigendecomposition (`_compute_eigenvalue_decomposition`) method.
class EighEigenvectorConfig(EigenvectorConfig, EigenvalueDecompositionConfig):
"""Configuration for eigenvectors via an eigendecomposition.
Args:
retry_double_precision (bool): Whether to re-trying eigendecomposition with higher(double) precision if lower precision fails due
retry_double_precision (bool): Whether to re-trying eigendecomposition with higher (double) precision if lower precision fails due
to CuSOLVER failure. (Default: True)
"""

retry_double_precision: bool = True


DefaultEighConfig = EighConfig()
DefaultEighEigenvectorConfig = EighEigenvectorConfig()


@dataclass(kw_only=True)
class QRConfig(EigenvectorConfig):
"""Configuration for orthogonal/simultaneous iterations/QR algorithm (`_compute_orthogonal_iterations`).
"""Configuration for eigenvectors via orthogonal/simultaneous iterations/QR algorithm.
Args:
max_iterations (int): The maximum number of iterations to perform. (Default: 1)
Expand Down
2 changes: 1 addition & 1 deletion tests/matrix_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ def test_invalid_eigenvalue_correction_config(
self.assertRaisesRegex(
NotImplementedError,
re.escape(
"Eigenvector computation method is not implemented! Specified eigenvector method is eigenvector_computation_config=EighConfig(retry_double_precision=True)."
"Eigenvector computation method is not implemented! Specified eigenvector method is eigenvector_computation_config=EighEigenvectorConfig(retry_double_precision=True)."
),
),
):
Expand Down

0 comments on commit 839a40a

Please sign in to comment.