Converts checkpointer and state builder to configurable. #751

Draft · wants to merge 1 commit into main
28 changes: 12 additions & 16 deletions axlearn/common/checkpointer.py
@@ -35,14 +35,10 @@
     Required,
     config_class,
     config_for_function,
+    maybe_instantiate,
 )
 from axlearn.common.metrics import WeightedScalar
-from axlearn.common.module import (
-    InvocationContext,
-    Module,
-    clone_context_stack,
-    install_context_stack,
-)
+from axlearn.common.module import InvocationContext, clone_context_stack, install_context_stack
 from axlearn.common.summary_writer import CheckpointerAction, SummaryWriter
 from axlearn.common.utils import (
     Nested,
@@ -630,7 +626,7 @@ def fn(*, step: int, evaler_summaries: dict[str, Any]) -> bool:
     return fn


-class BaseCheckpointer(Module):
+class BaseCheckpointer(Configurable):
     """A base checkpointer interface.

     Checkpointers are required to implement `save`, `restore`, `stop`, and `checkpoint_paths`.
@@ -641,7 +637,7 @@ class BaseCheckpointer(Module):
     """

     @config_class
-    class Config(Module.Config):
+    class Config(Configurable.Config):
         """Configures BaseCheckpointer.

         Attributes:
@@ -680,8 +676,8 @@ def latest_checkpoint_path(cls, base_dir: str) -> str:
         # Note: checkpoint_paths should already filter incomplete checkpoints.
         return sorted(cls.checkpoint_paths(base_dir)).pop()

-    def __init__(self, cfg: Module.Config, *, parent: Optional[Module]):
-        super().__init__(cfg, parent=parent)
+    def __init__(self, cfg: Configurable.Config):
+        super().__init__(cfg)
         self._within_context = False

     def __enter__(self):
@@ -824,8 +820,8 @@ def cleanup_checkpoint(cls, ckpt_dir: str, *, sync: bool = True):
         # Wait for cleanup to complete.
         multihost_utils.sync_global_devices(f"{ckpt_dir}_cleanup")

-    def __init__(self, cfg: Config, *, parent: Optional[Module]):
-        super().__init__(cfg, parent=parent)
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
         cfg: Checkpointer.Config = self.config

         self._storage: StateStorage = cfg.storage.instantiate()
@@ -834,7 +830,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
         self._save_policy: CheckpointPolicy = cfg.save_policy.instantiate()
         if cfg.summary_writer is not None:
             cfg.summary_writer.dir = cfg.summary_writer.dir or cfg.dir
-            self._add_child("summary_writer", cfg.summary_writer)
+        self.summary_writer: Optional[SummaryWriter] = maybe_instantiate(cfg.summary_writer)

     def __enter__(self):
         super().__enter__()
@@ -845,7 +841,7 @@ def _start_gc_thread(self):
         if self._gc_thread is None and jax.process_index() == 0:
             self._gc_stopping = threading.Event()
             self._gc_thread = threading.Thread(
-                name=f"{self.path()}.gc_loop",
+                name=f"{self.__class__.__name__}.gc_loop",
                 target=self._gc_loop,
                 kwargs=dict(context_stack=clone_context_stack()),
             )
@@ -894,7 +890,7 @@ def save(
         self._storage.save_to_dir(
             step=step, state=state, ckpt_dir=ckpt_dir, on_commit_callback=write_index_file
         )
-        if "summary_writer" in self.children:
+        if self.summary_writer is not None:
             self.summary_writer.log_checkpoint(
                 step=step,
                 state=state,
@@ -1009,7 +1005,7 @@ def validate_and_restore(*, step: int, ckpt_dir: str):
                 step=step, state=state, ckpt_dir=ckpt_dir
             )
             logging.info("Restored state from ckpt at step %s", step)
-            if "summary_writer" in self.children:
+            if self.summary_writer is not None:
                 self.summary_writer.log_checkpoint(
                     step=step,
                     state=state,
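
Note: the net effect of this file's changes is that checkpointers no longer participate in the `Module` hierarchy. A minimal sketch of the resulting call-site change, assuming axlearn's config API as used in the tests below; the path is illustrative:

# Sketch of the usage change implied by this diff; "/tmp/ckpt" is illustrative.
from axlearn.common.checkpointer import Checkpointer

cfg = Checkpointer.default_config().set(dir="/tmp/ckpt", keep_last_n=3)

# Before this PR (Module-based): a name and a parent were required, e.g.
#   ckpt = cfg.set(name="ckpt").instantiate(parent=None)

# After this PR (Configurable-based): no name, no parent.
ckpt = cfg.instantiate()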
5 changes: 2 additions & 3 deletions axlearn/common/checkpointer_orbax.py
@@ -28,7 +28,6 @@
     restore_tf_savables,
 )
 from axlearn.common.config import config_class
-from axlearn.common.module import Module
 from axlearn.common.utils import Nested, Tensor, TensorSpec


@@ -123,8 +122,8 @@ def checkpoint_paths(cls, base_dir: str) -> List[str]:
         """See `BaseCheckpointer.checkpointer_paths`."""
         return [str(path) for path in ocp.utils.checkpoint_steps_paths(base_dir)]

-    def __init__(self, cfg: Config, *, parent: Optional[Module]):
-        super().__init__(cfg, parent=parent)
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)

         cfg: OrbaxCheckpointer.Config = self.config
         save_policy = cfg.save_policy.instantiate()
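
Note: subclasses follow the same constructor shape as `OrbaxCheckpointer` above. A sketch for a hypothetical subclass (`MyCheckpointer` is illustrative, not part of this PR):

from axlearn.common.checkpointer import BaseCheckpointer

class MyCheckpointer(BaseCheckpointer):
    # Hypothetical subclass: under Configurable, __init__ takes only cfg.
    def __init__(self, cfg: BaseCheckpointer.Config):
        super().__init__(cfg)  # no parent= keyword anymore
        # Subclass-specific setup goes here.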
6 changes: 1 addition & 5 deletions axlearn/common/checkpointer_orbax_test.py
@@ -31,11 +31,7 @@ def test_index(self):
         if not test_utils.is_supported_mesh_shape(mesh_shape):
             return
         with _mesh(mesh_shape), tempfile.TemporaryDirectory() as temp_dir:
-            ckpt = (
-                OrbaxCheckpointer.default_config()
-                .set(name="test", dir=temp_dir)
-                .instantiate(parent=None)
-            )
+            ckpt = OrbaxCheckpointer.default_config().set(dir=temp_dir).instantiate()
             step = 123
             state = dict(x=jnp.ones([3, 2]))
             ckpt.save(step=step, state=state)
44 changes: 23 additions & 21 deletions axlearn/common/checkpointer_test.py
@@ -63,7 +63,6 @@ def _checkpointer_config(
 ) -> BaseCheckpointer.Config:
     # TODO(markblee): Use context manager instead of mkdtemp.
     return checkpointer_cls.default_config().set(
-        name="test",
         dir=tempfile.mkdtemp(),
         keep_last_n=1,
     )
@@ -78,7 +77,7 @@ def test_save_and_restore(self, checkpointer_cls: Type[BaseCheckpointer]):
         with _mesh(mesh_shape):
             cfg = _checkpointer_config(checkpointer_cls)
             cfg.save_policy.min_step = 0
-            ckpt: BaseCheckpointer = cfg.instantiate(parent=None)
+            ckpt: BaseCheckpointer = cfg.instantiate()
             state0 = dict(x=jnp.zeros([], dtype=jnp.int32), y=jnp.ones([2], dtype=jnp.float32))
             state1 = dict(x=jnp.ones([], dtype=jnp.int32), y=jnp.ones([2], dtype=jnp.float32) + 1)

@@ -155,7 +154,7 @@ def test_save_and_restore_mesh(self, checkpointer_cls: Type[BaseCheckpointer]):
             return

         cfg = _checkpointer_config(checkpointer_cls)
-        ckpt: BaseCheckpointer = cfg.instantiate(parent=None)
+        ckpt: BaseCheckpointer = cfg.instantiate()
         state = dict(
             x=jax.random.uniform(jax.random.PRNGKey(123), shape=[8, 4], dtype=jnp.float32),
         )
@@ -203,7 +202,7 @@ def test_save_and_restore_latest_valid(self, checkpointer_cls: Type[BaseCheckpoi
             return
         with _mesh(mesh_shape):
             cfg = _checkpointer_config(checkpointer_cls)
-            ckpt: BaseCheckpointer = cfg.instantiate(parent=None)
+            ckpt: BaseCheckpointer = cfg.instantiate()
             state0 = dict(x=jnp.zeros([], dtype=jnp.int32), y=jnp.ones([2], dtype=jnp.float32))

             # Restoring from an empty dir returns the input state if step=None.
@@ -251,7 +250,7 @@ def test_gda(self, checkpointer_cls, mesh_shape):
             return
         with _mesh(mesh_shape):
             cfg = _checkpointer_config(checkpointer_cls)
-            ckpt: Checkpointer = cfg.instantiate(parent=None)
+            ckpt: Checkpointer = cfg.instantiate()
             state = dict(x=jnp.arange(16).reshape((4, 4)))
             ckpt.save(step=10, state=state)
             ckpt.wait_until_finished()
@@ -282,7 +281,7 @@ def test_custom_dict(self, checkpointer_cls, custom_dict_type):
             return
         with _mesh(mesh_shape):
             cfg = _checkpointer_config(checkpointer_cls)
-            ckpt: Checkpointer = cfg.instantiate(parent=None)
+            ckpt: Checkpointer = cfg.instantiate()
             state0 = custom_dict_type(
                 x=jnp.zeros([], dtype=jnp.int32), y=jnp.ones([2], dtype=jnp.float32)
             )
@@ -307,7 +306,7 @@ def test_input_iterator(self, checkpointer_cls):
             return
         with _mesh(mesh_shape):
             cfg = _checkpointer_config(checkpointer_cls)
-            ckpt: Checkpointer = cfg.instantiate(parent=None)
+            ckpt: Checkpointer = cfg.instantiate()
             input_iter = iter(tf.data.Dataset.from_tensor_slices([1, 2, 3]))
             # Move the input_iter.
             self.assertEqual(next(input_iter), 1)
@@ -409,11 +408,14 @@ def patch_tf_io_behavior(*args):
             return [x + "/" for x in out if not x.endswith("/")]

         # pylint: disable=line-too-long
-        with _mesh(mesh_shape), mock.patch(
-            "tensorflow.io.gfile.listdir", patch_tf_io_behavior
-        ) if listdir_add_trailing_slash else nullcontext(), tempfile.TemporaryDirectory() as temp_dir:
+        with (
+            _mesh(mesh_shape),
+            mock.patch("tensorflow.io.gfile.listdir", patch_tf_io_behavior)
+            if listdir_add_trailing_slash
+            else nullcontext(),
+            tempfile.TemporaryDirectory() as temp_dir,
+        ):
             cfg = Checkpointer.default_config().set(
-                name="test",
                 dir=temp_dir,
                 keep_last_n=3,
                 keep_every_n_steps=2,
@@ -422,10 +424,10 @@ def patch_tf_io_behavior(*args):
             cfg.save_policy.min_step = 0

             # Running gc for non-existent dir shouldn't fail.
-            ckpt_fake = cfg.clone(dir=os.path.join(temp_dir, "fake_dir")).instantiate(parent=None)
+            ckpt_fake = cfg.clone(dir=os.path.join(temp_dir, "fake_dir")).instantiate()
             ckpt_fake._run_garbage_collection()

-            ckpt: Checkpointer = cfg.instantiate(parent=None)
+            ckpt: Checkpointer = cfg.instantiate()
             state = dict(x=jnp.zeros([], dtype=jnp.int32))

             for step in range(10):
@@ -532,7 +534,7 @@ def test_stop(self):
         if not test_utils.is_supported_mesh_shape(mesh_shape):
             return
         cfg = _checkpointer_config()
-        ckpt: Checkpointer = cfg.instantiate(parent=None)
+        ckpt: Checkpointer = cfg.instantiate()
        # GC thread is not started until the start_gc_thread() call.
        self.assertIsNone(ckpt._gc_thread)

@@ -550,7 +552,7 @@ def test_stop(self):

     @parameterized.parameters([Checkpointer, OrbaxCheckpointer])
     def test_context(self, checkpointer_cls):
-        ckpt = _checkpointer_config(checkpointer_cls).instantiate(parent=None)
+        ckpt = _checkpointer_config(checkpointer_cls).instantiate()

         if checkpointer_cls is Checkpointer:
             with ckpt:
@@ -568,7 +570,7 @@ def test_context(self, checkpointer_cls):

     def test_stop_on_exception(self):
         # Ensure that checkpointer gc thread terminates if there's an exception.
-        ckpt = _checkpointer_config().instantiate(parent=None)
+        ckpt = _checkpointer_config().instantiate()

         def run():
             ckpt._start_gc_thread()
@@ -598,7 +600,7 @@ def test_summary_writer_checkpoint(self):
         with _mesh(mesh_shape):
             cfg = _checkpointer_config()
             cfg.summary_writer = SummaryWriter.default_config()
-            ckpt: Checkpointer = cfg.instantiate(parent=None)
+            ckpt: Checkpointer = cfg.instantiate()
             self.assertIsNotNone(ckpt.summary_writer)

             ckpt.summary_writer.log_checkpoint = mock.Mock()
@@ -634,7 +636,7 @@ def _create_metric(value):
                     metric=EvalMetric(evaler_name="evaler", metric_name="metric"), mode=mode
                 )
             )
-            ckpt: Checkpointer = cfg.instantiate(parent=None)
+            ckpt: Checkpointer = cfg.instantiate()
             state0 = dict(x=jnp.zeros([], dtype=jnp.int32))
             state2 = dict(x=jnp.ones([], dtype=jnp.int32) * 2)
             state4 = dict(x=jnp.ones([], dtype=jnp.int32) * 4)
@@ -677,7 +679,7 @@ def test_best_metric_policy_value_error(self, checkpointer_cls):
                 metric=EvalMetric(evaler_name="evaler", metric_name="metric"), mode="max"
             )
         )
-        ckpt: Checkpointer = cfg.instantiate(parent=None)
+        ckpt: Checkpointer = cfg.instantiate()
        state0 = dict(x=jnp.zeros([], dtype=jnp.int32))

        with pytest.raises(ValueError, match=re.escape("evaler_summaries is empty")):
@@ -766,7 +768,7 @@ def test_read_state_spec(self, checkpointer_cls: Type[BaseCheckpointer]):
         with _mesh(mesh_shape):
             cfg = _checkpointer_config(checkpointer_cls)
             cfg.save_policy.min_step = 0
-            ckpt: BaseCheckpointer = cfg.instantiate(parent=None)
+            ckpt: BaseCheckpointer = cfg.instantiate()
             state0 = dict(
                 **{
                     f"v_{str(dtype.dtype)}": jnp.zeros([], dtype=dtype)
@@ -826,7 +828,7 @@ def tree_unflatten(cls, keys, values):
         with unittest.mock.patch.dict(globals(), {"SWITCHABLE_VDICT_IMPL": OldVDict}):
             cfg = _checkpointer_config()
             cfg.save_policy.min_step = 0
-            ckpt: Checkpointer = cfg.instantiate(parent=None)
+            ckpt: Checkpointer = cfg.instantiate()
             # VDict with out of order keys.
             state0 = dict(a=3, b=SwitchableVDict(d=6, b=5))
             state0 = jax.tree.map(jnp.asarray, state0)
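
Note: besides dropping `parent=None`, the garbage-collection test reflows its long `with` statement into the parenthesized form, which requires Python 3.10+. A standalone sketch of that construction, including the conditional `mock.patch`/`nullcontext` switch; the patch target here is illustrative, not the one used in the test:

import tempfile
from contextlib import nullcontext
from unittest import mock

patch_listdir = True  # toggles whether the patch is applied
with (
    mock.patch("os.listdir", lambda _: []) if patch_listdir else nullcontext(),
    tempfile.TemporaryDirectory() as temp_dir,
):
    print(temp_dir)  # the patched (or unpatched) context is active here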
15 changes: 8 additions & 7 deletions axlearn/common/evaler.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Evaler and base metric calculators."""
+
 import functools
 import graphlib
 import os.path
@@ -15,19 +16,21 @@
 from jax import numpy as jnp
 from jax.experimental.pjit import pjit

-from axlearn.common import struct, summary_writer, utils
+from axlearn.common import struct, utils
 from axlearn.common.base_model import BaseModel
 from axlearn.common.config import (
     REQUIRED,
     InstantiableConfig,
     Required,
     config_class,
     config_for_function,
+    maybe_instantiate,
 )
 from axlearn.common.inference_output import BaseOutputWriter
 from axlearn.common.metrics import MetricAccumulator, WeightedScalar
 from axlearn.common.module import Module, OutputCollection
 from axlearn.common.module import functional as F
+from axlearn.common.summary_writer import BaseWriter, SummaryWriter
 from axlearn.common.utils import (
     NestedPartitionSpec,
     NestedTensor,
@@ -558,7 +561,7 @@ class Config(Module.Config):
         # The input source.
         input: Required[InstantiableConfig] = REQUIRED
         # A summary writer to log tagged summary values.
-        summary_writer: InstantiableConfig = summary_writer.SummaryWriter.default_config()
+        summary_writer: BaseWriter.Config = SummaryWriter.default_config()
         # Run this evaler according to this policy.
         eval_policy: InstantiableConfig = config_for_function(every_n_steps_policy)
         # Which evaluation iters to trace with the profiler each time the evaler is run.
@@ -597,10 +600,8 @@ def __init__(
             model=model,
             model_param_partition_specs=model_param_partition_specs,
         )
-        self._add_child("summary_writer", cfg.summary_writer)
-        if cfg.output_writer is not None:
-            self._add_child("output_writer", cfg.output_writer)
-
+        self.output_writer: Optional[BaseOutputWriter] = maybe_instantiate(cfg.output_writer)
+        self.summary_writer: BaseWriter = cfg.summary_writer.instantiate()
         self._trace_steps = set()
         self._eval_policy: EvalPolicy = cfg.eval_policy.instantiate()
@@ -696,7 +697,7 @@ def eval_step(
             )
             metric_calculator_state = forward_outputs["state"]
             all_metric_calculator_outputs.append(forward_outputs["output"])
-            if "output_writer" in self.children:
+            if self.output_writer is not None:
                 self.output_writer.write(
                     input_batch=global_input_batch, output_batch=forward_outputs["output"]
                 )
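
Note: the evaler now stores optional components as plain attributes instead of Module children, using `maybe_instantiate` (imported above from `axlearn.common.config`), which returns None when given a None config. A condensed sketch of the pattern; `EvalerSketch` and its cfg fields are illustrative stand-ins for the evaler's real config:

from axlearn.common.config import maybe_instantiate

class EvalerSketch:
    def __init__(self, cfg):
        # None config -> None attribute; otherwise an instantiated writer.
        self.output_writer = maybe_instantiate(cfg.output_writer)
        # Required component: instantiate unconditionally.
        self.summary_writer = cfg.summary_writer.instantiate()

    def write(self, input_batch, output_batch):
        # Replaces the old `if "output_writer" in self.children:` check.
        if self.output_writer is not None:
            self.output_writer.write(input_batch=input_batch, output_batch=output_batch)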
4 changes: 1 addition & 3 deletions axlearn/common/inference.py
@@ -225,9 +225,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
             lambda spec: spec.mesh_axes, self._inference_runner_state_specs
         )
         logging.info("Building ckpt state from %s", cfg.init_state_builder.klass.__name__)
-        builder = cfg.init_state_builder.set(
-            name="init_state_builder",
-        ).instantiate(parent=None)
+        builder: Builder = cfg.init_state_builder.instantiate()

         # Check that builder should expect tensor specs.
         if builder.input_state_type() != Builder.StateType.TENSOR_SPECS:
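
Note: the same simplification covers state builders, the other half of this PR's title: the builder config is instantiated directly instead of being named and parented. A sketch, with the import path assumed from axlearn's layout and `builder_cfg` standing in for any concrete `Builder` config (e.g. `cfg.init_state_builder` above):

from axlearn.common.state_builder import Builder

def build_state_builder(builder_cfg: Builder.Config) -> Builder:
    builder: Builder = builder_cfg.instantiate()  # no name=, no parent=
    # Callers can still check what input the builder expects.
    assert builder.input_state_type() == Builder.StateType.TENSOR_SPECS
    return builder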