tensor_utils state dict tooling improvements, wip training progress manager

- `muutils.tensor_utils` quality of life improvements;
  - improvements to `compare_state_dict`: unique errors for key, shape, and value discrepancies
  - added `get_dict_shapes` and `string_dict_shapes` for nicely printing dotlist-style dicts of tensors, useful for activation caches and model state dicts
- WIP training progress bar / manager features, not ready for anything
mivanit authored Oct 4, 2023
2 parents 834c57a + 9e331d4 commit b4a203f
Showing 2 changed files with 293 additions and 10 deletions.
muutils/_wip/training_progress.py (190 additions, 0 deletions)
@@ -0,0 +1,190 @@
from pathlib import Path
from types import TracebackType
from typing import Literal, Mapping, Sequence, Type

from tqdm import tqdm

IntervalType = Literal["epochs", "batches", "samples"]


class TrainingProgress:
def __init__(
self,
# required
epochs: int,
dataloader: "torch.utils.data.DataLoader",
# saving:
model: "zanj.torchutil.ConfiguredModel|None" = None,
base_path: str | Path | None = None,
zanj: "zanj.Zanj|None" = None,
model_save_path_final: str | None = "model.final.zanj",
model_save_path_except: str | None = "model.exception.zanj",
# records
model_records_key: str = "training_records",
records_keys: Sequence[str] = ("loss", "accuracy", "grad_norm", "lr"),
records_save_path_external: str = "training_records.json",
# checkpoints
model_save_path_checkpoint: str = "checkpoints/checkpoint-{checkpoint}.zanj",
interval_checkpoint: str | tuple[int | None, IntervalType] = (None, "epochs"),
# progress bar config
fixed_width: int = 100,
pbar_disable: bool = False,
display_progress: IntervalType = "batches",
display_keys: dict[str, str] = {"loss": "L", "accuracy": "A"},
):
"""progress bar and training records saving context manager
given a number of epochs and dataloader, creates progress bars and stores training records
- you should call `update` periodically for batches, passing the index of the batch (in the epoch) and a dict of records
- you should call `next_epoch` at the end of each epoch
# Parameters:
- `epochs : int`
number of epochs
- `dataloader : torch.utils.data.DataLoader`
dataloader to get the number of batches and samples per epoch
- `model : zanj.torchutil.ConfiguredModel|None`
optional model to save records to (primarily for use with `zanj`)
(defaults to `None`)
- `model_save_path_except : str | None`
path to save the model to if an exception occurs
(defaults to `"model.exception.zanj"`)
- `model_records_key : str`
key to save records to in the model (if `model` is not `None`)
(defaults to `"training_records"` -- the default for a `zanj.torchutil.ConfiguredModel`)
- `zanj : zanj.Zanj|None`
optional configured `zanj` instance to save the model with, passed to `model.save()`
(defaults to `None`)
- `records_keys : Sequence[str]`
keys to save records for -- records are `dict[str, dict[int, float]]` where outer dict has record names as keys, and inner dict has overall batch index as keys and record values as values
(defaults to `("loss", "accuracy", "grad_norm", "lr")`)
- `fixed_width : int`
width of the progress bar
(defaults to `100`)
        - `display_progress : Literal["epochs", "batches", "samples"]`
            units for the progress bar -- `"epochs"` means a single overall bar, while `"batches"` or `"samples"` create a new bar for each epoch
            (defaults to `"batches"`)
        - `display_keys : dict[str, str]`
            mapping from record keys to the short labels shown in the progress bar
            (defaults to `{"loss": "L", "accuracy": "A"}`)
"""
# number of epochs, batches, and samples
self.epochs_total: int = epochs
self.batches_per_epoch: int = len(dataloader)
self.batches_total: int = self.batches_per_epoch * epochs
self.samples_per_epoch: int = len(dataloader.dataset)
self.samples_total: int = self.samples_per_epoch * epochs

# training records
self.records: Mapping[str, dict[int, float]] = {
key: dict() for key in records_keys
}
self.records_keys: Sequence[str] = records_keys
        self.model: "zanj.torchutil.ConfiguredModel|None" = model
        self.model_records_key: str = model_records_key
        self.model_save_path_except: str | None = model_save_path_except
        self.zanj: "zanj.Zanj|None" = zanj

# progress bar
        self.fixed_width: int = fixed_width
        self.pbar_disable: bool = pbar_disable
self.current_epoch: int = 0
self.current_batch_in_epoch: int = 0
        assert all(
            key in records_keys for key in display_keys
        ), f"all display_keys must be in records_keys; got display_keys = {display_keys}, records_keys = {records_keys}"
        self.display_keys: dict[str, str] = display_keys
assert display_progress in (
"epochs",
"batches",
"samples",
), f"display_progress must be one of 'epochs', 'batches', or 'samples', got {display_progress}"
self.display_progress: Literal[
"epochs", "batches", "samples"
] = display_progress
self.pbar: tqdm

        # single progress bar if over epochs, multiple otherwise
        self.multibar: bool = self.display_progress != "epochs"
self.pbar_total: int = (
self.batches_per_epoch
if self.display_progress == "batches"
else self.samples_per_epoch
if self.display_progress == "samples"
else self.epochs_total
)

def _get_desc(self) -> str:
return f"Epoch {self.current_epoch+1}/{self.epochs_total}"

def __enter__(self):
self.pbar = tqdm(
total=self.pbar_total,
desc=self._get_desc(),
position=0,
leave=True,
ncols=self.fixed_width,
dynamic_ncols=False,
disable=self.pbar_disable,
)
return self

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ):
# if the model was passed
if self.model is not None:
# add records
if self.model_records_key is not None:
# add records to model
if getattr(self.model, self.model_records_key, None) is None:
# create a dict
setattr(self.model, self.model_records_key, self.records)
elif isinstance(getattr(self.model, self.model_records_key), dict):
# add to existing dict
getattr(self.model, self.model_records_key).update(self.records)
else:
# unexpected type
raise TypeError(
f"model_records_key '{self.model_records_key}' must be a dict, got {type(getattr(self.model, self.model_records_key))}"
)

# if error
if exc_type is not None:
# add exception info
getattr(self.model, self.model_records_key)["exception"] = dict(
type=str(exc_type),
value=str(exc_val),
traceback=str(exc_tb),
)
# save the model
self.model.save(self.model_save_path_except, zanj=self.zanj)

# close the progress bar
self.pbar.close()

def update(self, batch_in_epoch_idx: int, records: Mapping[str, float]) -> None:
"""update records and progress bar, call once a batch is complete"""
# update records
for key in records:
self.records[key][
self.current_epoch * self.batches_per_epoch + batch_in_epoch_idx
] = records[key]

        self.pbar.set_postfix(
            {
                label: records[key]
                for key, label in self.display_keys.items()
                if key in records
            }
        )
self.pbar.update(batch_in_epoch_idx - self.current_batch_in_epoch)
self.current_batch_in_epoch = batch_in_epoch_idx

def next_epoch(self) -> None:
"""call at the end of each epoch"""
self.current_epoch += 1
self.current_batch_in_epoch = 0
if self.multibar:
self.pbar.close()
self.pbar = tqdm(
total=self.pbar_total,
desc=self._get_desc(),
position=0,
leave=True,
ncols=self.fixed_width,
dynamic_ncols=False,
disable=self.pbar_disable,
)
else:
self.pbar.set_description(self._get_desc())
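
A minimal usage sketch for the context manager above (illustrative only, not part of this commit; the model, optimizer, and record values are hypothetical stand-ins):

# illustrative sketch -- not part of this commit
import torch
from torch.utils.data import DataLoader, TensorDataset

from muutils._wip.training_progress import TrainingProgress

EPOCHS: int = 2
dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
dataloader = DataLoader(dataset, batch_size=16)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

with TrainingProgress(epochs=EPOCHS, dataloader=dataloader) as tp:
    for epoch in range(EPOCHS):
        for batch_idx, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(x), y)
            loss.backward()
            optimizer.step()
            # update with the in-epoch batch index and a dict of record values;
            # record keys must come from `records_keys`
            tp.update(batch_idx, {"loss": loss.item(), "accuracy": 0.0})
        tp.next_epoch()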
muutils/tensor_utils.py (103 additions, 10 deletions)
@@ -1,10 +1,13 @@
import json
import typing
import warnings

import jaxtyping
import numpy as np
import torch

from muutils.dictmagic import dotlist_to_nested_dict

# pylint: disable=missing-class-docstring


@@ -344,13 +347,103 @@ def rpad_array(
return pad_array(array, pad_length, pad_value, rpad=True)


def compare_state_dicts(d1: dict, d2: dict):
assert d1.keys() == d2.keys(), "state dict keys don't match!"
keys_failed: list[str] = list()
for k, v in d1.items():
v_load = d2[k]
if not (v == v_load).all():
keys_failed.append(k)
assert (
len(keys_failed) == 0
), f"{len(keys_failed)} / {len(d1)} state dict elements don't match: {keys_failed}"
def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, typing.Any]:
"""given a state dict or cache dict, compute the shapes and put them in a nested dict"""
return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})


def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
"""printable version of get_dict_shapes"""
return json.dumps(
dotlist_to_nested_dict(
{
                k: str(
                    tuple(v.shape)
                )  # stringify, since json's indent won't play nice with tuples
for k, v in d.items()
}
),
indent=2,
)
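

# a quick illustration (not part of this commit) of how the two helpers above
# nest dotlist-style keys, e.g. for an activation cache -- assuming
# `dotlist_to_nested_dict` splits keys on "." as the commit message describes:
#
#   >>> d = {"blocks.0.attn": torch.zeros(2, 4), "blocks.0.mlp": torch.zeros(2, 8)}
#   >>> get_dict_shapes(d)
#   {'blocks': {'0': {'attn': (2, 4), 'mlp': (2, 8)}}}
#   >>> print(string_dict_shapes(d))
#   {
#     "blocks": {
#       "0": {
#         "attn": "(2, 4)",
#         "mlp": "(2, 8)"
#       }
#     }
#   }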


class StateDictCompareError(AssertionError):
"""raised when state dicts don't match"""

pass


class StateDictKeysError(StateDictCompareError):
"""raised when state dict keys don't match"""

pass


class StateDictShapeError(StateDictCompareError):
"""raised when state dict shapes don't match"""

pass


class StateDictValueError(StateDictCompareError):
"""raised when state dict values don't match"""

pass


def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two state dicts, raising a subclass of `StateDictCompareError` on mismatch

    raises `StateDictKeysError` if the key sets differ, `StateDictShapeError` if any
    shared tensor's shape differs, and `StateDictValueError` if values differ beyond
    `torch.allclose(v1, v2, rtol=rtol, atol=atol)`. if `verbose`, the shapes of the
    offending tensors are included in the error message
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    # sort for easier debugging -- a `set` has no order, so keep sorted lists
    symmetric_diff: list = sorted(set.symmetric_difference(d1_keys, d2_keys))
    keys_diff_1: list = sorted(d1_keys - d2_keys)
    keys_diff_2: list = sorted(d2_keys - d1_keys)
diff_shapes_1: str = (
string_dict_shapes({k: d1[k] for k in keys_diff_1})
if verbose
else "(verbose = False)"
)
diff_shapes_2: str = (
string_dict_shapes({k: d2[k] for k in keys_diff_2})
if verbose
else "(verbose = False)"
)
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dict keys do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

# check tensors match
shape_failed: list[str] = list()
vals_failed: list[str] = list()
for k, v1 in d1.items():
v2 = d2[k]
# check shapes first
if not v1.shape == v2.shape:
shape_failed.append(k)
else:
# if shapes match, check values
if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
vals_failed.append(k)

str_shape_failed: str = (
string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
)
str_vals_failed: str = (
string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
)

if not len(shape_failed) == 0:
raise StateDictShapeError(
f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
)
if not len(vals_failed) == 0:
raise StateDictValueError(
f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
)
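
A sketch of how the new error hierarchy might be caught (illustrative only, not part of this commit):

# illustrative sketch -- not part of this commit
import torch

from muutils.tensor_utils import (
    StateDictShapeError,
    StateDictValueError,
    compare_state_dicts,
)

d1 = {"w": torch.ones(2, 3), "b": torch.zeros(3)}
d2 = {"w": torch.ones(2, 3), "b": torch.full((3,), 1e-3)}

try:
    compare_state_dicts(d1, d2)
except StateDictShapeError as err:
    print(f"shape mismatch: {err}")
except StateDictValueError as err:
    # "b" differs by 1e-3, well outside the default atol=1e-8
    print(f"value mismatch: {err}")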
