diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 234c598..4074fea 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -93,6 +93,12 @@ class PreconditionerConfig(AbstractDataclass): amortized_computation_config: MatrixFunctionConfig num_tolerated_failed_amortized_computations: int = 3 + def __post_init__(self) -> None: + if self.num_tolerated_failed_amortized_computations < 0: + raise ValueError( + f"Invalid num_tolerated_failed_amortized_computations value: {self.num_tolerated_failed_amortized_computations}. Must be >= 0." + ) + @dataclass(kw_only=True) class ShampooPreconditionerConfig(PreconditionerConfig): diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py index 2ec773f..dd9a59d 100644 --- a/distributed_shampoo/tests/shampoo_types_test.py +++ b/distributed_shampoo/tests/shampoo_types_test.py @@ -9,12 +9,23 @@ import re import unittest -from typing import Type +from abc import ABC, abstractmethod +from typing import Generic, Type, TypeVar from distributed_shampoo.shampoo_types import ( AdaGradGraftingConfig, AdamGraftingConfig, + EigenvalueCorrectedShampooPreconditionerConfig, + PreconditionerConfig, RMSpropGraftingConfig, + ShampooPreconditionerConfig, +) +from matrix_functions_types import ( + DefaultEigenConfig, + DefaultEighConfig, + EigenvectorConfig, + MatrixFunctionConfig, + RootInvConfig, ) @@ -69,3 +80,72 @@ def _get_grafting_config_type( self, ) -> Type[RMSpropGraftingConfig] | Type[AdamGraftingConfig]: return AdamGraftingConfig + + +PreconditionerConfigType = TypeVar( + "PreconditionerConfigType", bound=Type[PreconditionerConfig] +) +AmortizedComputationConfigType = TypeVar( + "AmortizedComputationConfigType", bound=MatrixFunctionConfig +) + + +class AbstractPreconditionerConfigTest: + class PreconditionerConfigTest( + ABC, + unittest.TestCase, + Generic[PreconditionerConfigType, AmortizedComputationConfigType], + ): + def test_illegal_num_tolerated_failed_amortized_computations(self) -> None: + num_tolerated_failed_amortized_computations = -1 + with ( + self.assertRaisesRegex( + ValueError, + re.escape( + f"Invalid num_tolerated_failed_amortized_computations value: " + f"{num_tolerated_failed_amortized_computations}. Must be >= 0." + ), + ), + ): + self._get_preconditioner_config_type()( + amortized_computation_config=self._get_amortized_computation_config(), + num_tolerated_failed_amortized_computations=num_tolerated_failed_amortized_computations, + ) + + @abstractmethod + def _get_preconditioner_config_type( + self, + ) -> PreconditionerConfigType: ... + + @abstractmethod + def _get_amortized_computation_config( + self, + ) -> AmortizedComputationConfigType: ... + + +class ShampooPreconditionerConfigTest( + AbstractPreconditionerConfigTest.PreconditionerConfigTest[ + Type[ShampooPreconditionerConfig], RootInvConfig + ] +): + def _get_amortized_computation_config(self) -> RootInvConfig: + return DefaultEigenConfig + + def _get_preconditioner_config_type( + self, + ) -> Type[ShampooPreconditionerConfig]: + return ShampooPreconditionerConfig + + +class EigenvalueCorrectedShampooPreconditionerConfigTest( + AbstractPreconditionerConfigTest.PreconditionerConfigTest[ + Type[EigenvalueCorrectedShampooPreconditionerConfig], EigenvectorConfig + ] +): + def _get_amortized_computation_config(self) -> EigenvectorConfig: + return DefaultEighConfig + + def _get_preconditioner_config_type( + self, + ) -> Type[EigenvalueCorrectedShampooPreconditionerConfig]: + return EigenvalueCorrectedShampooPreconditionerConfig