From 89fd01d609019ec5bced42e32be3cdccfa59bab6 Mon Sep 17 00:00:00 2001 From: Haoran Zhang Date: Mon, 4 Nov 2024 16:35:16 -0800 Subject: [PATCH 1/4] fix world rank Summary: as title Created from CodeHub with https://fburl.com/edit-in-codehub Reviewed By: hjmshi Differential Revision: D65452009 fbshipit-source-id: 3200c7d19602c7c864e9e64b17c6cccd343ff637 --- distributed_shampoo/examples/default_cifar10_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed_shampoo/examples/default_cifar10_example.py b/distributed_shampoo/examples/default_cifar10_example.py index 0fa5024..5643806 100644 --- a/distributed_shampoo/examples/default_cifar10_example.py +++ b/distributed_shampoo/examples/default_cifar10_example.py @@ -104,7 +104,7 @@ def train_default_model( # instantiate data loader. Note that this is a single GPU training example, # so we do not need to instantiate a sampler. - data_loader, _ = get_data_loader_and_sampler(args.data_path, 1, 1, args.batch_size) + data_loader, _ = get_data_loader_and_sampler(args.data_path, 1, 0, args.batch_size) # instantiate optimizer (SGD, Adam, DistributedShampoo) optimizer = instantiate_optimizer( From 39370a97085e5b3616f58e706a5e4228d192ea4d Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 5 Nov 2024 11:04:51 -0800 Subject: [PATCH 2/4] Fix `EigenvalueCorrectedShampooPreconditionerListTest` (#35) Summary: `EigenvalueCorrectedShampooPreconditionerList` has to be added to the `isinstance` check explicitly since it does not inherit from `ShampooPreconditionerList`. This increases code coverage since [this line](https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/utils/shampoo_preconditioner_list.py#L1232-L1235) and [this line](https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/utils/shampoo_preconditioner_list.py#L1245-L1249) are also covered. Pull Request resolved: https://github.com/facebookresearch/optimizers/pull/35 Reviewed By: anana10c Differential Revision: D65486426 Pulled By: tsunghsienlee fbshipit-source-id: 7a5eb4275e591434218433fad7334fed4b2e5c71 --- .../utils/tests/shampoo_preconditioner_list_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 56e9687..fac5c4d 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -20,6 +20,7 @@ from distributed_shampoo.utils.shampoo_block_info import BlockInfo from distributed_shampoo.utils.shampoo_preconditioner_list import ( AdagradPreconditionerList, + BaseShampooPreconditionerList, DequantizePreconditionersContext, EigenvalueCorrectedShampooPreconditionerList, PreconditionerList, @@ -74,9 +75,9 @@ def _test_update_preconditioners_and_precondition( preconditioner_list.update_preconditioners( masked_grad_list=masked_grad_list, step=torch.tensor(step), - # Only compute the new layerwise direction when the update_preconditioners() reach the last step. + # Only update the complete preconditioner during the last call to update_preconditioners(). 
perform_amortized_computation=isinstance( - preconditioner_list, ShampooPreconditionerList + preconditioner_list, BaseShampooPreconditionerList ) and step == len(masked_grad_lists), ) From 894320b3c28b23a55b9c573ea3b03580dc0cc697 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 5 Nov 2024 11:07:18 -0800 Subject: [PATCH 3/4] Add support for specifying a `PreconditionerComputationType` in the examples (#34) Summary: Also, to resolve an error I had to change the rank that is passed to `get_data_loader_and_sampler` in `default_cifar10_example.py` from 1 to 0. Pull Request resolved: https://github.com/facebookresearch/optimizers/pull/34 Reviewed By: anana10c Differential Revision: D65486358 Pulled By: tsunghsienlee fbshipit-source-id: 9c13ab325b43da66d22b9a73e9e78afdc07e198a --- .../examples/ddp_cifar10_example.py | 3 + .../examples/default_cifar10_example.py | 3 + .../examples/fsdp_cifar10_example.py | 3 + .../examples/fully_shard_cifar10_example.py | 3 + .../examples/hsdp_cifar10_example.py | 3 + distributed_shampoo/examples/trainer_utils.py | 62 +++++++++++++++++++ 6 files changed, 77 insertions(+) diff --git a/distributed_shampoo/examples/ddp_cifar10_example.py b/distributed_shampoo/examples/ddp_cifar10_example.py index 6140ccb..910cea4 100644 --- a/distributed_shampoo/examples/ddp_cifar10_example.py +++ b/distributed_shampoo/examples/ddp_cifar10_example.py @@ -125,12 +125,15 @@ computation_dtype=args.computation_dtype.value, factor_matrix_dtype=args.factor_matrix_dtype.value, inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value, + corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value, + factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value, filtered_grad_dtype=args.filtered_grad_dtype.value, momentum_dtype=args.momentum_dtype.value, grafting_state_dtype=args.grafting_state_dtype.value, ), use_protected_eigh=args.use_protected_eigh, track_root_inv_residuals=args.track_root_inv_residuals, + preconditioner_computation_type=args.preconditioner_computation_type, ) # checks for checkpointing diff --git a/distributed_shampoo/examples/default_cifar10_example.py b/distributed_shampoo/examples/default_cifar10_example.py index 5643806..176c01f 100644 --- a/distributed_shampoo/examples/default_cifar10_example.py +++ b/distributed_shampoo/examples/default_cifar10_example.py @@ -135,12 +135,15 @@ def train_default_model( computation_dtype=args.computation_dtype.value, factor_matrix_dtype=args.factor_matrix_dtype.value, inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value, + corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value, + factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value, filtered_grad_dtype=args.filtered_grad_dtype.value, momentum_dtype=args.momentum_dtype.value, grafting_state_dtype=args.grafting_state_dtype.value, ), use_protected_eigh=args.use_protected_eigh, track_root_inv_residuals=args.track_root_inv_residuals, + preconditioner_computation_type=args.preconditioner_computation_type, ) # train model diff --git a/distributed_shampoo/examples/fsdp_cifar10_example.py b/distributed_shampoo/examples/fsdp_cifar10_example.py index 653415b..483f270 100644 --- a/distributed_shampoo/examples/fsdp_cifar10_example.py +++ b/distributed_shampoo/examples/fsdp_cifar10_example.py @@ -120,12 +120,15 @@ computation_dtype=args.computation_dtype.value, factor_matrix_dtype=args.factor_matrix_dtype.value, inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value, + 
corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value, + factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value, filtered_grad_dtype=args.filtered_grad_dtype.value, momentum_dtype=args.momentum_dtype.value, grafting_state_dtype=args.grafting_state_dtype.value, ), use_protected_eigh=args.use_protected_eigh, track_root_inv_residuals=args.track_root_inv_residuals, + preconditioner_computation_type=args.preconditioner_computation_type, ) # train model diff --git a/distributed_shampoo/examples/fully_shard_cifar10_example.py b/distributed_shampoo/examples/fully_shard_cifar10_example.py index 20b48bb..f66be9c 100644 --- a/distributed_shampoo/examples/fully_shard_cifar10_example.py +++ b/distributed_shampoo/examples/fully_shard_cifar10_example.py @@ -137,12 +137,15 @@ def create_model_and_optimizer_and_loss_fn(args, device): computation_dtype=args.computation_dtype.value, factor_matrix_dtype=args.factor_matrix_dtype.value, inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value, + corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value, + factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value, filtered_grad_dtype=args.filtered_grad_dtype.value, momentum_dtype=args.momentum_dtype.value, grafting_state_dtype=args.grafting_state_dtype.value, ), use_protected_eigh=args.use_protected_eigh, track_root_inv_residuals=args.track_root_inv_residuals, + preconditioner_computation_type=args.preconditioner_computation_type, ) return model, optimizer, loss_function diff --git a/distributed_shampoo/examples/hsdp_cifar10_example.py b/distributed_shampoo/examples/hsdp_cifar10_example.py index 64ea762..b162f02 100644 --- a/distributed_shampoo/examples/hsdp_cifar10_example.py +++ b/distributed_shampoo/examples/hsdp_cifar10_example.py @@ -135,12 +135,15 @@ computation_dtype=args.computation_dtype.value, factor_matrix_dtype=args.factor_matrix_dtype.value, inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value, + corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value, + factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value, filtered_grad_dtype=args.filtered_grad_dtype.value, momentum_dtype=args.momentum_dtype.value, grafting_state_dtype=args.grafting_state_dtype.value, ), use_protected_eigh=args.use_protected_eigh, track_root_inv_residuals=args.track_root_inv_residuals, + preconditioner_computation_type=args.preconditioner_computation_type, ) # train model diff --git a/distributed_shampoo/examples/trainer_utils.py b/distributed_shampoo/examples/trainer_utils.py index 9efe9fb..6d909f7 100644 --- a/distributed_shampoo/examples/trainer_utils.py +++ b/distributed_shampoo/examples/trainer_utils.py @@ -31,6 +31,13 @@ RMSpropGraftingConfig, SGDGraftingConfig, ) +from matrix_functions_types import ( + CoupledHigherOrderConfig, + CoupledNewtonConfig, + EigenConfig, + EighEigenvalueCorrectionConfig, + PreconditionerComputationConfig, +) from torch import nn from torchvision import datasets, transforms @@ -62,6 +69,13 @@ class GraftingType(enum.Enum): ADAM = 4 +class PreconditionerComputationType(enum.Enum): + EIGEN_ROOT_INV = 0 + COUPLED_NEWTON_ROOT_INV = 1 + COUPLED_HIGHER_ORDER_ROOT_INV = 2 + EIGH_EIGENVALUE_CORRECTION = 3 + + ###### ARGPARSER ###### def enum_type_parse(s: str, enum_type: enum.Enum): try: @@ -195,6 +209,12 @@ def get_args(): action="store_true", help="Use debug mode for examining root inverse residuals.", ) + parser.add_argument( + "--preconditioner-computation-type", + type=lambda t: enum_type_parse(t, 
PreconditionerComputationType), + default=PreconditionerComputationType.EIGEN_ROOT_INV, + help="Preconditioner computation method for Shampoo.", + ) # Arguments for grafting. parser.add_argument( @@ -235,6 +255,18 @@ def get_args(): default=DType.FP32, help="Data type for storing Shampoo inverse factor matrices.", ) + parser.add_argument( + "--corrected-eigenvalues-dtype", + type=lambda t: enum_type_parse(t, DType), + default=DType.FP32, + help="Data type for storing corrected eigenvalues of Shampoo preconditioner.", + ) + parser.add_argument( + "--factor-matrix-eigenvectors-dtype", + type=lambda t: enum_type_parse(t, DType), + default=DType.FP32, + help="Data type for storing Shampoo factor matrices eigenvectors.", + ) parser.add_argument( "--filtered-grad-dtype", type=lambda t: enum_type_parse(t, DType), @@ -410,6 +442,7 @@ def instantiate_optimizer( precision_config: Optional[PrecisionConfig], use_protected_eigh: bool, track_root_inv_residuals: bool, + preconditioner_computation_type: PreconditionerComputationType, ) -> torch.optim.Optimizer: if optimizer_type == OptimizerType.SGD: optimizer = torch.optim.SGD( @@ -464,6 +497,9 @@ def instantiate_optimizer( precision_config=precision_config, use_protected_eigh=use_protected_eigh, track_root_inv_residuals=track_root_inv_residuals, + preconditioner_computation_config=instantiate_preconditioner_computation_config( + preconditioner_computation_type + ), ) else: raise ValueError(f"Invalid OptimizerType {optimizer_type}!") @@ -501,6 +537,32 @@ def instantiate_grafting_config( raise ValueError(f"Invalid GraftingType {grafting_type}!") +def instantiate_preconditioner_computation_config( + preconditioner_computation_type: PreconditionerComputationType, +) -> PreconditionerComputationConfig: + if preconditioner_computation_type == PreconditionerComputationType.EIGEN_ROOT_INV: + return EigenConfig() + elif ( + preconditioner_computation_type + == PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV + ): + return CoupledNewtonConfig() + elif ( + preconditioner_computation_type + == PreconditionerComputationType.COUPLED_HIGHER_ORDER_ROOT_INV + ): + return CoupledHigherOrderConfig() + elif ( + preconditioner_computation_type + == PreconditionerComputationType.EIGH_EIGENVALUE_CORRECTION + ): + return EighEigenvalueCorrectionConfig() + else: + raise ValueError( + f"Invalid PreconditionerComputationType {preconditioner_computation_type}!" + ) + + ###### DATA LOADER ###### def get_data_loader_and_sampler( data_path: str, world_size: int, rank: int, local_batch_size: int From 397ad17fc995f631efacf383744bf6bb3f94a056 Mon Sep 17 00:00:00 2001 From: Tsung-Hsien Lee Date: Tue, 5 Nov 2024 12:07:00 -0800 Subject: [PATCH 4/4] Open-sourced update on 11/05/2024 Summary: 1. Refactor `shampoo_preconditioner_list_test.py` for sharing the test fixtures. 2. Some small code improvements. 
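
The fixture sharing in (1) uses the common pattern of nesting the abstract base test case inside a plain wrapper class so that unittest discovery only collects the concrete subclasses (see the `AbstractTest.BaseShampooPreconditionerListTest` wrapper in the diff below). As a minimal, hypothetical sketch of that pattern — class and method names here are simplified illustrations, not the ones in this diff:

```python
import abc
import unittest


# Nesting the abstract base inside a non-TestCase wrapper keeps unittest's
# module-level discovery from collecting and running the abstract tests;
# only the concrete subclasses defined at module level are picked up.
class AbstractTest:
    class BasePreconditionerTest(abc.ABC, unittest.TestCase):
        @abc.abstractmethod
        def _amortized_computation_function(self) -> str:
            """Name of the amortized matrix computation the subclass exercises."""
            ...

        def test_reports_amortized_computation_function(self) -> None:
            # Shared test body: every concrete subclass must report a
            # non-empty function name.
            self.assertTrue(self._amortized_computation_function())


class RootInvPreconditionerTest(AbstractTest.BasePreconditionerTest):
    def _amortized_computation_function(self) -> str:
        return "matrix_inverse_root"


class EigenvectorPreconditionerTest(AbstractTest.BasePreconditionerTest):
    def _amortized_computation_function(self) -> str:
        return "matrix_eigenvectors"


if __name__ == "__main__":
    unittest.main()
```

Each concrete test class only overrides the pieces that differ (here, the name of the amortized computation), while all shared fixtures and test bodies live once in the abstract base.
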
Reviewed By: anana10c Differential Revision: D65494156 fbshipit-source-id: 0fca4d824a19612f1cbed403b64947f6a6be4296 --- .../gpu_tests/shampoo_grafting_test.py | 11 +- .../tests/shampoo_test_utils.py | 16 +- .../tests/shampoo_preconditioner_list_test.py | 545 ++++++------------ 3 files changed, 201 insertions(+), 371 deletions(-) diff --git a/distributed_shampoo/gpu_tests/shampoo_grafting_test.py b/distributed_shampoo/gpu_tests/shampoo_grafting_test.py index a0c7d2a..229eee6 100644 --- a/distributed_shampoo/gpu_tests/shampoo_grafting_test.py +++ b/distributed_shampoo/gpu_tests/shampoo_grafting_test.py @@ -9,6 +9,7 @@ #!/usr/bin/env python3 +import math import unittest from functools import partial from itertools import product @@ -118,7 +119,7 @@ def test_adagrad_grafting_on_quadratic(self) -> None: momentum=0.0, max_preconditioner_dim=10, precondition_frequency=1, - start_preconditioning_step=1000, + start_preconditioning_step=math.inf, use_decoupled_weight_decay=False, grafting_config=AdaGradGraftingConfig( epsilon=1e-10, @@ -153,7 +154,7 @@ def test_adam_grafting_on_quadratic(self) -> None: momentum=0.0, max_preconditioner_dim=10, precondition_frequency=1, - start_preconditioning_step=1000, + start_preconditioning_step=math.inf, use_decoupled_weight_decay=False, grafting_config=AdamGraftingConfig( beta2=0.999, @@ -189,7 +190,7 @@ def test_adamw_grafting_on_quadratic(self) -> None: momentum=0.0, max_preconditioner_dim=10, precondition_frequency=1, - start_preconditioning_step=1000, + start_preconditioning_step=math.inf, use_decoupled_weight_decay=True, grafting_config=AdamGraftingConfig( beta2=0.999, @@ -228,7 +229,7 @@ def test_rmsprop_grafting_on_quadratic(self) -> None: momentum=0.0, max_preconditioner_dim=10, precondition_frequency=1, - start_preconditioning_step=1000, + start_preconditioning_step=math.inf, use_decoupled_weight_decay=False, grafting_config=RMSpropGraftingConfig( beta2=0.99, @@ -269,7 +270,7 @@ def test_sgd_grafting_on_quadratic(self) -> None: epsilon=1e-10, max_preconditioner_dim=10, precondition_frequency=1, - start_preconditioning_step=1000, + start_preconditioning_step=math.inf, use_nesterov=use_nesterov, use_decoupled_weight_decay=False, grafting_config=SGDGraftingConfig(), diff --git a/distributed_shampoo/tests/shampoo_test_utils.py b/distributed_shampoo/tests/shampoo_test_utils.py index 51825e4..7c43d44 100644 --- a/distributed_shampoo/tests/shampoo_test_utils.py +++ b/distributed_shampoo/tests/shampoo_test_utils.py @@ -8,7 +8,7 @@ """ import itertools -from typing import Optional +from functools import reduce import torch from torch import nn @@ -18,7 +18,7 @@ class _ModelWithLinearAndDeadLayers(nn.Module): def __init__( self, model_linear_layers_dims: tuple[int, ...], - model_dead_layer_dims: Optional[tuple[int, ...]], + model_dead_layer_dims: tuple[int, ...] | None, bias: bool = False, ) -> None: super().__init__() @@ -38,15 +38,13 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: - for linear_layer in self.linear_layers: - x = linear_layer(x) - return x + return reduce(lambda x, layer: layer(x), self.linear_layers, x) def construct_training_problem( model_linear_layers_dims: tuple[int, ...], - model_dead_layer_dims: Optional[tuple[int, ...]] = (10, 10), - device: Optional[torch.device] = None, + model_dead_layer_dims: tuple[int, ...] | None = (10, 10), + device: torch.device | None = None, bias: bool = False, fill: float | tuple[float, ...] 
= 0.0, ) -> tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]: @@ -55,8 +53,8 @@ def construct_training_problem( Args: model_linear_layers_dims (tuple[int, ...]): The dimensions of the model linear layers. - model_dead_layer_dims (Optional[tuple[int, ...]]): The dimensions of the model dead linear layers. (Default: (10, 10)) - device (Optional[torch.device]): The device to use. (Default: None) + model_dead_layer_dims (tuple[int, ...] | None): The dimensions of the model dead linear layers. (Default: (10, 10)) + device (torch.device | None): The device to use. (Default: None) bias (bool): Whether to use bias in the linear (non-dead) layers. (Default: False) fill (float | tuple[float, ...]): The value(s) to fill the model parameters. If a tuple, each element should correspond to one layer. (Default: 0.0) diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index fac5c4d..9c92161 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -7,6 +7,7 @@ """ +import abc import re import unittest from types import ModuleType @@ -274,7 +275,187 @@ def test_compress_preconditioner_list(self) -> None: self._test_compress_preconditioner_list(expected_compress_list_call_count=1) -class ShampooPreconditionerListTest(AdagradPreconditionerListTest): +# Use outer class as wrapper to avoid running the abstract test. +class AbstractTest: + class BaseShampooPreconditionerListTest(abc.ABC, AdagradPreconditionerListTest): + @abc.abstractmethod + def _amortized_computation_function(self) -> str: ... + + @abc.abstractmethod + def _instantiate_preconditioner_list( + self, **kwargs: Any + ) -> PreconditionerList: ... + + def _test_raise_invalid_value_in_factor_matrix( + self, invalid_value: float + ) -> None: + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), self.assertRaisesRegex( + PreconditionerValueError, + re.escape(f"Encountered {str(invalid_value)} values in factor matrix"), + ): + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([invalid_value, invalid_value]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[invalid_value, invalid_value]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + + # Because nan as the input of self._preconditioner_list.update_preconditioners() would change the internal state to nan (and stayed as nan even after other updates), + # we need to test the cases of nan and inf separately. 
+ def test_raise_inf_in_factor_matrix(self) -> None: + self._test_raise_invalid_value_in_factor_matrix(invalid_value=torch.inf) + + def test_raise_nan_in_factor_matrix(self) -> None: + self._test_raise_invalid_value_in_factor_matrix(invalid_value=torch.nan) + + def test_raise_nan_and_inf_in_inv_factor_matrix_amortized_computation( + self, + ) -> None: + for invalid_value in (torch.nan, torch.inf): + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), self.subTest(invalid_value=invalid_value), self.assertRaisesRegex( + PreconditionerValueError, + re.escape("Encountered nan or inf values in"), + ), mock.patch.object( + shampoo_preconditioner_list, + self._amortized_computation_function(), + side_effect=(torch.tensor([invalid_value]),), + ) as mock_amortized_computation: + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + mock_amortized_computation.assert_called_once() + + def test_amortized_computation_internal_failure(self) -> None: + with mock.patch.object( + shampoo_preconditioner_list, + self._amortized_computation_function(), + # Simulate the situation throws an exception (not nan and inf) to test the warning + side_effect=ZeroDivisionError, + ) as mock_amortized_computation: + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), self.assertLogs(level="WARNING") as cm: + # Because use_protected_eigh is True, we expect the warning to be logged. + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + self.assertCountEqual( + # Only extracts the first sentence in the warning message for simple comparison. + [r.msg.split(". ", maxsplit=1)[0] for r in cm.records], + [ + "Matrix computation failed for factor matrix 0.block_0.0 with exception=ZeroDivisionError()", + "Matrix computation failed for factor matrix 1.block_0.0 with exception=ZeroDivisionError()", + "Matrix computation failed for factor matrix 1.block_0.1 with exception=ZeroDivisionError()", + "Matrix computation failed for factor matrix 1.block_1.0 with exception=ZeroDivisionError()", + "Matrix computation failed for factor matrix 1.block_1.1 with exception=ZeroDivisionError()", + ], + ) + mock_amortized_computation.assert_called() + mock_amortized_computation.reset_mock() + + # Turn off use_protected_eigh and expect ZeroDivisionError to be logged. + self._preconditioner_list = self._instantiate_preconditioner_list( + use_protected_eigh=False, + ) + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), self.assertRaises(ZeroDivisionError): + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + mock_amortized_computation.assert_called() + + # Note: This is needed for type checking to infer the type of argument into mock.patch.object. 
+ shampoo_preconditioner_list_module: ModuleType = shampoo_preconditioner_list + + @mock.patch.object( + shampoo_preconditioner_list_module, + "check_diagonal", + return_value=False, + ) + def test_amortized_computation_factor_matrix_non_diagonal( + self, mock_check_diagonal: mock.Mock + ) -> None: + self._preconditioner_list = self._instantiate_preconditioner_list( + epsilon=1.0 + ) + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), self.assertLogs( + level="DEBUG", + ) as cm: + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + self.assertCountEqual( + [r.msg for r in cm.records], + [ + "Factor matrix 0.block_0.0 is not diagonal.", + "Factor matrix 1.block_0.0 is not diagonal.", + "Factor matrix 1.block_0.1 is not diagonal.", + "Factor matrix 1.block_1.0 is not diagonal.", + "Factor matrix 1.block_1.1 is not diagonal.", + ], + ) + mock_check_diagonal.assert_called() + + def test_numel_list(self) -> None: + self.assertEqual(self._preconditioner_list.numel_list, (8, 16, 10)) + + def test_dims_list(self) -> None: + self.assertEqual( + self._preconditioner_list.dims_list, + (torch.Size([2]), torch.Size([2, 2]), torch.Size([1, 2])), + ) + + def test_num_bytes_list(self) -> None: + self.assertEqual(self._preconditioner_list.num_bytes_list, (48, 96, 60)) + + def test_numel(self) -> None: + self.assertEqual(self._preconditioner_list.numel(), 34) + + def test_num_bytes(self) -> None: + self.assertEqual(self._preconditioner_list.num_bytes(), 204) + + def test_compress_preconditioner_list(self) -> None: + self._test_compress_preconditioner_list(expected_compress_list_call_count=3) + + +class ShampooPreconditionerListTest(AbstractTest.BaseShampooPreconditionerListTest): + def _amortized_computation_function(self) -> str: + return "matrix_inverse_root" + def _instantiate_preconditioner_list(self, **kwargs: Any) -> PreconditionerList: kwargs = { "beta2": 1.0, @@ -511,208 +692,6 @@ def test_inverse_roots_from_override( test_inverse_roots_from_override(inv_root_override=2) test_inverse_roots_from_override(inv_root_override=[2, 2, 2]) - def test_raise_inf_in_factor_matrix_compute_root_inverse(self) -> None: - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ): - with self.assertRaisesRegex( - PreconditionerValueError, - re.escape("Encountered inf values in factor matrix"), - ): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([torch.inf, torch.inf]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[torch.inf, torch.inf]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - - def test_raise_nan_in_factor_matrix_compute_root_inverse(self) -> None: - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ): - with self.assertRaisesRegex( - PreconditionerValueError, - re.escape("Encountered nan values in factor matrix"), - ): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([torch.nan, torch.nan]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[torch.nan, torch.nan]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - - # Note: This is needed for pyre to infer the type of argument into mock.patch.object. 
- shampoo_preconditioner_list_module: ModuleType = shampoo_preconditioner_list - - @mock.patch.object( - shampoo_preconditioner_list_module, - "matrix_inverse_root", - side_effect=(torch.tensor([torch.inf]),), - ) - def test_raise_inf_in_inv_factor_matrix_compute_root_inverse( - self, mock_matrix_inverse_root: mock.Mock - ) -> None: - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.assertRaisesRegex( - PreconditionerValueError, - re.escape("Encountered nan or inf values in inverse factor matrix"), - ): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([1.0, 0.0]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[1.0, 0.0]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - mock_matrix_inverse_root.assert_called_once() - - @mock.patch.object( - shampoo_preconditioner_list_module, - "matrix_inverse_root", - side_effect=(torch.tensor([torch.nan]),), - ) - def test_raise_nan_in_inv_factor_matrix_compute_root_inverse( - self, mock_matrix_inverse_root: mock.Mock - ) -> None: - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.assertRaisesRegex( - PreconditionerValueError, - re.escape("Encountered nan or inf values in inverse factor matrix"), - ): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([1.0, 0.0]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[1.0, 0.0]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - mock_matrix_inverse_root.assert_called_once() - - @mock.patch.object( - shampoo_preconditioner_list_module, - "matrix_inverse_root", - # Simulate the situation matrix_inverse_root throws an exception (not nan and inf) to test the warning - side_effect=ZeroDivisionError, - ) - def test_matrix_compute_root_inverse_internal_failure( - self, mock_matrix_inverse_root: mock.Mock - ) -> None: - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.assertLogs(level="WARNING") as cm: - # Because use_protected_eigh is True, we expect the warning to be logged. - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([1.0, 0.0]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[1.0, 0.0]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - self.assertCountEqual( - [r.msg for r in cm.records], - [ - "Matrix computation failed for factor matrix 0.block_0.0 with exception=ZeroDivisionError()." - " Using previous inversed factor matrix and continuing...", - "Matrix computation failed for factor matrix 1.block_0.0 with exception=ZeroDivisionError()." - " Using previous inversed factor matrix and continuing...", - "Matrix computation failed for factor matrix 1.block_0.1 with exception=ZeroDivisionError()." - " Using previous inversed factor matrix and continuing...", - "Matrix computation failed for factor matrix 1.block_1.0 with exception=ZeroDivisionError()." - " Using previous inversed factor matrix and continuing...", - "Matrix computation failed for factor matrix 1.block_1.1 with exception=ZeroDivisionError()." - " Using previous inversed factor matrix and continuing...", - ], - ) - mock_matrix_inverse_root.assert_called() - - # Turn off use_protected_eigh and expect ZeroDivisionError to be logged. 
- self._preconditioner_list = self._instantiate_preconditioner_list( - use_protected_eigh=False, - ) - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.assertRaises(ZeroDivisionError): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([1.0, 0.0]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[1.0, 0.0]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - mock_matrix_inverse_root.assert_called() - - @mock.patch.object( - shampoo_preconditioner_list_module, - "check_diagonal", - return_value=False, - ) - def test_matrix_compute_root_inverse_factor_matrix_non_diagonal( - self, mock_check_diagonal: mock.Mock - ) -> None: - self._preconditioner_list = self._instantiate_preconditioner_list(epsilon=1.0) - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.assertLogs( - level="DEBUG", - ) as cm: - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([1.0, 0.0]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[1.0, 0.0]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - self.assertCountEqual( - [r.msg for r in cm.records], - [ - "Factor matrix 0.block_0.0 is not diagonal.", - "Factor matrix 1.block_0.0 is not diagonal.", - "Factor matrix 1.block_0.1 is not diagonal.", - "Factor matrix 1.block_1.0 is not diagonal.", - "Factor matrix 1.block_1.1 is not diagonal.", - ], - ) - mock_check_diagonal.assert_called() - - def test_numel_list(self) -> None: - self.assertEqual(self._preconditioner_list.numel_list, (8, 16, 10)) - - def test_dims_list(self) -> None: - self.assertEqual( - self._preconditioner_list.dims_list, - (torch.Size([2]), torch.Size([2, 2]), torch.Size([1, 2])), - ) - - def test_num_bytes_list(self) -> None: - self.assertEqual(self._preconditioner_list.num_bytes_list, (48, 96, 60)) - - def test_numel(self) -> None: - self.assertEqual(self._preconditioner_list.numel(), 34) - - def test_num_bytes(self) -> None: - self.assertEqual(self._preconditioner_list.num_bytes(), 204) - - def test_compress_preconditioner_list(self) -> None: - self._test_compress_preconditioner_list(expected_compress_list_call_count=3) - def test_compute_root_inverse_residuals(self) -> None: """ Create a factor matrix of size 2x2 by updating preconditioners in two steps: @@ -757,7 +736,12 @@ def test_compute_root_inverse_residuals(self) -> None: self.assertTupleEqual(relative_residuals, expected_relative_residuals) -class EigenvalueCorrectedShampooPreconditionerListTest(AdagradPreconditionerListTest): +class EigenvalueCorrectedShampooPreconditionerListTest( + AbstractTest.BaseShampooPreconditionerListTest +): + def _amortized_computation_function(self) -> str: + return "matrix_eigenvectors" + def _instantiate_preconditioner_list(self, **kwargs: Any) -> PreconditionerList: kwargs = { "beta2": 1.0, @@ -1006,156 +990,3 @@ def test_inverse_roots_from_override( test_inverse_roots_from_override(inv_root_override=1) test_inverse_roots_from_override(inv_root_override=[1, 1, 1]) - - """Tests for compute_preconditioner_eigenvectors.""" - - def test_raise_inf_in_factor_matrix_compute_preconditioner_eigenvectors( - self, - ) -> None: - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ): - with self.assertRaisesRegex( - ValueError, - re.escape("Encountered inf values in factor matrix"), - ): - self._preconditioner_list.update_preconditioners( - 
masked_grad_list=( - torch.tensor([torch.inf, torch.inf]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[torch.inf, torch.inf]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - - def test_raise_nan_in_factor_matrix_compute_preconditioner_eigenvectors( - self, - ) -> None: - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ): - with self.assertRaisesRegex( - ValueError, - re.escape("Encountered nan values in factor matrix"), - ): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([torch.nan, torch.nan]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[torch.nan, torch.nan]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - - # Note: This is needed for pyre to infer the type of argument into mock.patch.object. - shampoo_preconditioner_list_module: ModuleType = shampoo_preconditioner_list - - @mock.patch.object( - shampoo_preconditioner_list_module, - "matrix_eigenvectors", - side_effect=(torch.tensor([torch.inf]),), - ) - def test_raise_inf_in_inv_factor_matrix_compute_preconditioner_eigenvectors( - self, mock_matrix_eigenvectors: mock.Mock - ) -> None: - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.assertRaisesRegex( - PreconditionerValueError, - re.escape("Encountered nan or inf values in eigenvectors of factor matrix"), - ): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([1.0, 0.0]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[1.0, 0.0]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - mock_matrix_eigenvectors.assert_called_once() - - @mock.patch.object( - shampoo_preconditioner_list_module, - "matrix_eigenvectors", - side_effect=(torch.tensor([torch.nan]),), - ) - def test_raise_nan_in_inv_factor_matrix_compute_preconditioner_eigenvectors( - self, mock_matrix_eigenvectors: mock.Mock - ) -> None: - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.assertRaisesRegex( - PreconditionerValueError, - re.escape("Encountered nan or inf values in eigenvectors of factor matrix"), - ): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([1.0, 0.0]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[1.0, 0.0]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - mock_matrix_eigenvectors.assert_called_once() - - @mock.patch.object( - shampoo_preconditioner_list_module, - "check_diagonal", - return_value=False, - ) - def test_matrix_compute_preconditioner_eigenvectors_factor_matrix_non_diagonal( - self, mock_check_diagonal: mock.Mock - ) -> None: - self._preconditioner_list = self._instantiate_preconditioner_list(epsilon=1.0) - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.assertLogs( - level="DEBUG", - ) as cm: - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([1.0, 0.0]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[1.0, 0.0]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - self.assertCountEqual( - [r.msg for r in cm.records], - [ - "Factor matrix 0.block_0.0 is not diagonal.", - "Factor matrix 1.block_0.0 is not diagonal.", - "Factor matrix 1.block_0.1 is not diagonal.", - "Factor matrix 1.block_1.0 is not diagonal.", - "Factor matrix 1.block_1.1 is not diagonal.", - ], - ) 
- mock_check_diagonal.assert_called() - - """End of tests for compute_preconditioner_eigenvectors.""" - - def test_numel_list(self) -> None: - self.assertEqual(self._preconditioner_list.numel_list, (8, 16, 10)) - - def test_dims_list(self) -> None: - self.assertEqual( - self._preconditioner_list.dims_list, - (torch.Size([2]), torch.Size([2, 2]), torch.Size([1, 2])), - ) - - def test_num_bytes_list(self) -> None: - self.assertEqual(self._preconditioner_list.num_bytes_list, (48, 96, 60)) - - def test_numel(self) -> None: - self.assertEqual(self._preconditioner_list.numel(), 34) - - def test_num_bytes(self) -> None: - self.assertEqual(self._preconditioner_list.num_bytes(), 204) - - def test_compress_preconditioner_list(self) -> None: - self._test_compress_preconditioner_list(expected_compress_list_call_count=3)