Skip to content

Commit

Permalink
Make failure tracking coarser
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Dec 9, 2024
1 parent 03245f5 commit 527c35e
Showing 1 changed file with 65 additions and 49 deletions.
114 changes: 65 additions & 49 deletions distributed_shampoo/utils/shampoo_preconditioner_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,15 +768,14 @@ def _initialize_state_lists(
self._inv_root_override,
self._local_order_list,
)
self._local_failed_amortized_computation_counter_list: list[list[int]] = [
[0] * len(kronecker_factors.factor_matrices)
for kronecker_factors in self._local_kronecker_factors_list
]
self._local_failed_amortized_computation_counter_list: list[int] = [0] * len(
self._local_kronecker_factors_list
)

# Masked lists are the list of active preconditioners or values after filtering out gradients with None.
self._masked_order_list: tuple[int, ...] = self._local_order_list
self._masked_root_list: tuple[int, ...] = self._local_root_list
self._masked_failed_amortized_computation_counter_list: list[list[int]] = (
self._masked_failed_amortized_computation_counter_list: list[int] = (
self._local_failed_amortized_computation_counter_list
)
self._masked_kronecker_factors_list: tuple[
Expand Down Expand Up @@ -814,7 +813,7 @@ def compress_preconditioner_list(
self._masked_root_list: tuple[int, ...] = compress_list( # type: ignore[no-redef]
self._local_root_list, local_grad_selector
)
self._masked_failed_amortized_computation_counter_list: list[list[int]] = ( # type: ignore[no-redef]
self._masked_failed_amortized_computation_counter_list: list[int] = ( # type: ignore[no-redef]
list(
compress_list(
self._local_failed_amortized_computation_counter_list,
Expand Down Expand Up @@ -966,25 +965,25 @@ def _amortized_computation(self) -> None:
with profiler.record_function(
f"## {self.__class__.__name__}:{self._amortized_computation.__name__} ##"
):
for kronecker_factors, root, fail_counter_list in zip(
self._masked_kronecker_factors_list,
self._masked_root_list,
self._masked_failed_amortized_computation_counter_list,
strict=True,
for idx, (kronecker_factors, root) in enumerate(
zip(
self._masked_kronecker_factors_list,
self._masked_root_list,
strict=True,
)
):
for idx, (
success_tracker: list[bool] = []
for (
factor_matrix,
inv_factor_matrix,
is_factor_matrix_diagonal,
factor_matrix_index,
) in enumerate(
zip(
kronecker_factors.factor_matrices.dequantized_value,
kronecker_factors.inv_factor_matrices.dequantized_value,
kronecker_factors.is_factor_matrices_diagonal,
kronecker_factors.factor_matrix_indices,
strict=True,
)
) in zip(
kronecker_factors.factor_matrices.dequantized_value,
kronecker_factors.inv_factor_matrices.dequantized_value,
kronecker_factors.is_factor_matrices_diagonal,
kronecker_factors.factor_matrix_indices,
strict=True,
):
# Add epsilon term and incorporate bias correction.
bias_corrected_factor_matrix = (
Expand Down Expand Up @@ -1017,19 +1016,15 @@ def _amortized_computation(self) -> None:
epsilon=self._epsilon,
is_diagonal=bool(is_factor_matrix_diagonal),
).to(dtype=inv_factor_matrix.dtype)
# Reset counter for failed amortized computations.
fail_counter_list[idx] = 0
# Add success to success tracker.
success_tracker.append(True)
except Exception as exception:
# If self._use_protected_eigh is True, will reuse previous matrix if matrix inverse root computation fails.
if not self._use_protected_eigh:
raise exception
else:
# Increment counter for failed amortized computations.
fail_counter_list[idx] += 1
# Only reuse previous matrix if tolerance is not exceeded.
self._raise_exception_if_tolerance_exceeded(
fail_counter_list[idx], exception
)
# Add failure to success tracker.
success_tracker.append(False)
logger.warning(
f"Matrix computation failed for factor matrix {factor_matrix_index} "
f"with {exception=}. Using previous inverted factor matrix and continuing..."
Expand All @@ -1049,6 +1044,20 @@ def _amortized_computation(self) -> None:
)
inv_factor_matrix.copy_(computed_inv_factor_matrix)

if all(success_tracker):
# Reset counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[idx] = 0
else:
# Increment counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[idx] += 1
# Only reuse previous eigenvectors if tolerance is not exceeded.
self._raise_exception_if_tolerance_exceeded(
self._masked_failed_amortized_computation_counter_list[idx],
ValueError(
f"Exceeded tolerance for number of failed root inverse computations for {kronecker_factors.factor_matrix_indices}."
),
)

def dequantize_preconditioners(self) -> None:
with profiler.record_function(
f"## {self.__class__.__name__}:{self.dequantize_preconditioners.__name__} ##"
Expand Down Expand Up @@ -1308,24 +1317,21 @@ def _amortized_computation(self) -> None:
with profiler.record_function(
f"## {self.__class__.__name__}:{self._amortized_computation.__name__} ##"
):
for kronecker_factors, fail_counter_list in zip(
self._masked_kronecker_factors_list,
self._masked_failed_amortized_computation_counter_list,
strict=True,
for idx, kronecker_factors in enumerate(
self._masked_kronecker_factors_list
):
for idx, (
success_tracker: list[bool] = []
for (
factor_matrix,
factor_matrix_eigenvectors,
is_factor_matrix_diagonal,
factor_matrix_index,
) in enumerate(
zip(
kronecker_factors.factor_matrices.dequantized_value,
kronecker_factors.factor_matrices_eigenvectors.dequantized_value,
kronecker_factors.is_factor_matrices_diagonal,
kronecker_factors.factor_matrix_indices,
strict=True,
)
) in zip(
kronecker_factors.factor_matrices.dequantized_value,
kronecker_factors.factor_matrices_eigenvectors.dequantized_value,
kronecker_factors.is_factor_matrices_diagonal,
kronecker_factors.factor_matrix_indices,
strict=True,
):
BaseShampooPreconditionerList._check_factor_matrix_for_diagonality_nan_and_inf(
factor_matrix=factor_matrix,
Expand All @@ -1345,19 +1351,15 @@ def _amortized_computation(self) -> None:
eigenvector_computation_config=eigenvector_computation_config,
is_diagonal=bool(is_factor_matrix_diagonal),
)
# Reset counter for failed amortized computations.
fail_counter_list[idx] = 0
# Add success to success tracker.
success_tracker.append(True)
except Exception as exception:
# If self._use_protected_eigh is True, will reuse previous matrix if matrix eigenvector computation fails.
if not self._use_protected_eigh:
raise exception
else:
# Increment counter for failed amortized computations.
fail_counter_list[idx] += 1
# Only reuse previous matrix if tolerance is not exceeded.
self._raise_exception_if_tolerance_exceeded(
fail_counter_list[idx], exception
)
# Add failure to success tracker.
success_tracker.append(False)
logger.warning(
f"Matrix computation failed for factor matrix {factor_matrix_index} "
f"with {exception=}. Using previous factor matrix eigenvectors and continuing..."
Expand All @@ -1377,6 +1379,20 @@ def _amortized_computation(self) -> None:
)
factor_matrix_eigenvectors.copy_(computed_eigenvectors)

if all(success_tracker):
# Reset counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[idx] = 0
else:
# Increment counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[idx] += 1
# Only reuse previous eigenvectors if tolerance is not exceeded.
self._raise_exception_if_tolerance_exceeded(
self._masked_failed_amortized_computation_counter_list[idx],
ValueError(
f"Exceeded tolerance for number of failed eigenvector computations for {kronecker_factors.factor_matrix_indices}."
),
)

def dequantize_preconditioners(self) -> None:
with profiler.record_function(
f"## {self.__class__.__name__}:{self.dequantize_preconditioners.__name__} ##"
Expand Down

0 comments on commit 527c35e

Please sign in to comment.