Skip to content

Commit

Permalink
Add check that tolerance value non-negative
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Dec 9, 2024
1 parent 9a137b5 commit 03245f5
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 1 deletion.
6 changes: 6 additions & 0 deletions distributed_shampoo/shampoo_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ class PreconditionerConfig(AbstractDataclass):
amortized_computation_config: MatrixFunctionConfig
num_tolerated_failed_amortized_computations: int = 3

def __post_init__(self) -> None:
if self.num_tolerated_failed_amortized_computations < 0:
raise ValueError(
f"Invalid num_tolerated_failed_amortized_computations value: {self.num_tolerated_failed_amortized_computations}. Must be >= 0."
)


@dataclass(kw_only=True)
class ShampooPreconditionerConfig(PreconditionerConfig):
Expand Down
82 changes: 81 additions & 1 deletion distributed_shampoo/tests/shampoo_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,23 @@

import re
import unittest
from typing import Type
from abc import ABC, abstractmethod
from typing import Generic, Type, TypeVar

from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdamGraftingConfig,
EigenvalueCorrectedShampooPreconditionerConfig,
PreconditionerConfig,
RMSpropGraftingConfig,
ShampooPreconditionerConfig,
)
from matrix_functions_types import (
DefaultEigenConfig,
DefaultEighConfig,
EigenvectorConfig,
MatrixFunctionConfig,
RootInvConfig,
)


Expand Down Expand Up @@ -69,3 +80,72 @@ def _get_grafting_config_type(
self,
) -> Type[RMSpropGraftingConfig] | Type[AdamGraftingConfig]:
return AdamGraftingConfig


PreconditionerConfigType = TypeVar(
"PreconditionerConfigType", bound=Type[PreconditionerConfig]
)
AmortizedComputationConfigType = TypeVar(
"AmortizedComputationConfigType", bound=MatrixFunctionConfig
)


class AbstractPreconditionerConfigTest:
class PreconditionerConfigTest(
ABC,
unittest.TestCase,
Generic[PreconditionerConfigType, AmortizedComputationConfigType],
):
def test_illegal_num_tolerated_failed_amortized_computations(self) -> None:
num_tolerated_failed_amortized_computations = -1
with (
self.assertRaisesRegex(
ValueError,
re.escape(
f"Invalid num_tolerated_failed_amortized_computations value: "
f"{num_tolerated_failed_amortized_computations}. Must be >= 0."
),
),
):
self._get_preconditioner_config_type()(
amortized_computation_config=self._get_amortized_computation_config(),
num_tolerated_failed_amortized_computations=num_tolerated_failed_amortized_computations,
)

@abstractmethod
def _get_preconditioner_config_type(
self,
) -> PreconditionerConfigType: ...

@abstractmethod
def _get_amortized_computation_config(
self,
) -> AmortizedComputationConfigType: ...


class ShampooPreconditionerConfigTest(
AbstractPreconditionerConfigTest.PreconditionerConfigTest[
Type[ShampooPreconditionerConfig], RootInvConfig
]
):
def _get_amortized_computation_config(self) -> RootInvConfig:
return DefaultEigenConfig

def _get_preconditioner_config_type(
self,
) -> Type[ShampooPreconditionerConfig]:
return ShampooPreconditionerConfig


class EigenvalueCorrectedShampooPreconditionerConfigTest(
AbstractPreconditionerConfigTest.PreconditionerConfigTest[
Type[EigenvalueCorrectedShampooPreconditionerConfig], EigenvectorConfig
]
):
def _get_amortized_computation_config(self) -> EigenvectorConfig:
return DefaultEighConfig

def _get_preconditioner_config_type(
self,
) -> Type[EigenvalueCorrectedShampooPreconditionerConfig]:
return EigenvalueCorrectedShampooPreconditionerConfig

0 comments on commit 03245f5

Please sign in to comment.