Skip to content

Commit

Permalink
better errors for state dict comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
mivanit committed Oct 2, 2023
1 parent 9a8098b commit 40c0795
Showing 1 changed file with 36 additions and 9 deletions.
45 changes: 36 additions & 9 deletions muutils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,30 @@ def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
)


class StateDictCompareError(AssertionError):
"""raised when state dicts don't match"""

pass


class StateDictKeysError(StateDictCompareError):
"""raised when state dict keys don't match"""

pass


class StateDictShapeError(StateDictCompareError):
"""raised when state dict shapes don't match"""

pass


class StateDictValueError(StateDictCompareError):
"""raised when state dict values don't match"""

pass


def compare_state_dicts(
d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
Expand All @@ -390,9 +414,10 @@ def compare_state_dicts(
if verbose
else "(verbose = False)"
)
assert (
len(symmetric_diff) == 0
), f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
if not len(symmetric_diff) == 0:
raise StateDictKeysError(
f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
)

# check tensors match
shape_failed: list[str] = list()
Expand All @@ -414,9 +439,11 @@ def compare_state_dicts(
string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
)

assert (
len(shape_failed) == 0
), f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
assert (
len(vals_failed) == 0
), f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
if not len(shape_failed) == 0:
raise StateDictShapeError(
f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
)
if not len(vals_failed) == 0:
raise StateDictValueError(
f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
)

0 comments on commit 40c0795

Please sign in to comment.