Add test for amortized computation failure tolerance
runame committed Dec 9, 2024
1 parent 32c5df5 commit 65f4f27
Showing 1 changed file with 113 additions and 2 deletions.
distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py: 113 additions & 2 deletions
@@ -142,7 +142,7 @@ def _test_compress_preconditioner_list(
) as mock_compress_quant_list,
):
# Count the number of list compressions at the preconditioner list level, including compressions of QuantizedTensorList.
- # Each call to compress() under QuantizedTensorList counts once, though note that it calls compress_list() three times inside.
+ # Each call to compress() under QuantizedTensorList counts once, though note that it calls compress_list() four times inside.
self.assertIsNone(
self._preconditioner_list.compress_preconditioner_list(
local_grad_selector=(True,) * len(self._block_list)
@@ -327,6 +327,9 @@ def test_abstract_methods(self) -> None:
# Use outer class as wrapper to avoid running the abstract test.
class AbstractTest:
class BaseShampooPreconditionerListTest(abc.ABC, AdagradPreconditionerListTest):
# Number of calls to the amortized computation function per update.
NUM_AMORTIZED_COMPUTATION_CALLS = 5

@abc.abstractmethod
def _amortized_computation_function(self) -> str: ...
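
For orientation, a hedged reading of where the value 5 comes from (the commit itself does not spell it out): the gradient blocks used in the new test have shapes (2,), (2, 2), and (1, 2), and Shampoo keeps one Kronecker factor per tensor dimension, so a single update with perform_amortized_computation=True should trigger 1 + 2 + 2 = 5 amortized computations, one per factor matrix.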

@@ -455,6 +458,114 @@ def test_amortized_computation_internal_failure(self) -> None:
)
mock_amortized_computation.assert_called()

def test_amortized_computation_failure_tolerance(self) -> None:
self._preconditioner_list = self._instantiate_preconditioner_list()
masked_grad_list0 = (
torch.tensor([1.0, 0.0]),
torch.eye(2) / torch.tensor(2.0).sqrt(),
torch.tensor([[1.0, 0.0]]),
)
masked_grad_list = (
torch.tensor([0.0, 1.0]),
torch.eye(2) / torch.tensor(2.0).sqrt(),
torch.tensor([[0.0, 1.0]]),
)

with mock.patch.object(
shampoo_preconditioner_list,
self._amortized_computation_function(),
side_effect=[
*(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS,
*(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS,
*(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS,
*(torch.tensor([1.0]),) * self.NUM_AMORTIZED_COMPUTATION_CALLS,
*(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS,
*(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS,
*(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS,
*(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS,
],
) as mock_amortized_computation:
with DequantizePreconditionersContext(
preconditioner_list=self._preconditioner_list
):
step = 1
# Accumulate factor matrices for valid amortized computation.
self._preconditioner_list.update_preconditioners(
masked_grad_list=masked_grad_list0,
step=torch.tensor(step),
perform_amortized_computation=False,
)
self.assertEqual(mock_amortized_computation.call_count, 0)
step += 1

# Case 1: amortized computation fails less often than tolerance -> no error.
self._preconditioner_list.update_preconditioners(
masked_grad_list=masked_grad_list,
step=torch.tensor(step),
perform_amortized_computation=True,
)
self.assertEqual(
mock_amortized_computation.call_count,
self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1),
)
step += 1

# Case 2: amortized computation fails exactly as often as tolerance (3) -> no error.
for _ in range(2):
self._preconditioner_list.update_preconditioners(
masked_grad_list=masked_grad_list,
step=torch.tensor(step),
perform_amortized_computation=True,
)
self.assertEqual(
mock_amortized_computation.call_count,
self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1),
)
step += 1

# Case 3: amortized computation succeeds after tolerance hit (test reset) -> no error.
self._preconditioner_list.update_preconditioners(
masked_grad_list=masked_grad_list,
step=torch.tensor(step),
perform_amortized_computation=True,
)
self.assertEqual(
mock_amortized_computation.call_count,
self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1),
)
step += 1

# Case 4: amortized computation fails more often than tolerance -> error.
for _ in range(3):
self._preconditioner_list.update_preconditioners(
masked_grad_list=masked_grad_list,
step=torch.tensor(step),
perform_amortized_computation=True,
)
self.assertEqual(
mock_amortized_computation.call_count,
self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1),
)
step += 1
# At tolerance now.
with self.assertRaises(ValueError):
with self.assertLogs(level="ERROR") as log:
self._preconditioner_list.update_preconditioners(
masked_grad_list=masked_grad_list,
step=torch.tensor(step),
perform_amortized_computation=True,
)
self.assertIn(
"Exceeded tolerance (3) for number of failed amortized computations.",
log.output,
)
# The error will be raised for the first Kronecker factor, so the
# expected call count should only increase by 1.
self.assertEqual(
mock_amortized_computation.call_count,
self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 2) + 1,
)

# Note: This is needed for type checking to infer the type of argument into mock.patch.object.
shampoo_preconditioner_list_module: ModuleType = shampoo_preconditioner_list
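
The failure schedule in the new test is driven entirely by mock's iterable side_effect: each call to the patched amortized-computation function consumes the next item, and an item that is an exception class or instance is raised rather than returned, so each group of NUM_AMORTIZED_COMPUTATION_CALLS items scripts one update step. A minimal standalone illustration of that mock behavior (the names below are made up for the example, not part of the test):

    from unittest import mock

    # Each call consumes the next side_effect item; exception classes or
    # instances are raised, any other value is returned.
    fake = mock.MagicMock(side_effect=[ValueError, ValueError, 42])

    for _ in range(2):
        try:
            fake()
        except ValueError:
            pass  # the first two calls raise

    assert fake() == 42        # the third call returns the next item
    assert fake.call_count == 3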

Expand Down Expand Up @@ -517,7 +628,7 @@ def test_num_bytes(self) -> None:
self.assertEqual(self._preconditioner_list.num_bytes(), 204)

def test_compress_preconditioner_list(self) -> None:
- self._test_compress_preconditioner_list(expected_compress_list_call_count=3)
+ self._test_compress_preconditioner_list(expected_compress_list_call_count=4)


class ShampooPreconditionerListTest(AbstractTest.BaseShampooPreconditionerListTest):
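
Taken together, the four cases in the new test pin down a simple contract: consecutive failures of the amortized computation are tolerated up to a limit of 3, a single success resets the failure count, and the failure that exceeds the limit re-raises the underlying error (after logging it). A minimal sketch of that contract, with hypothetical names and not the distributed_shampoo implementation:

    class FailureToleranceSketch:
        """Tolerate up to `tolerance` consecutive failures; reset on success."""

        def __init__(self, tolerance: int = 3) -> None:
            self._tolerance = tolerance
            self._consecutive_failures = 0

        def run(self, computation):
            try:
                result = computation()
            except Exception:
                self._consecutive_failures += 1
                if self._consecutive_failures > self._tolerance:
                    # Exceeded the tolerance: propagate the error (the test also
                    # expects an ERROR log at this point).
                    raise
                return None  # within tolerance: fall back to the previous result
            self._consecutive_failures = 0  # success resets the counter
            return result

Under this reading, Case 3 is the key one: a single successful computation wipes the failure history, so the three later failures in Case 4 are tolerated again, and only the fourth consecutive failure raises. The final call-count assertion is consistent with this as well: steps 2 through 8 each issue 5 calls (35 in total), and the failing update at step 9 stops at the first Kronecker factor, giving 5 * (9 - 2) + 1 = 36 expected calls.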