From 2e50095780fdb31003d7a2dcc71d45f78dbe2377 Mon Sep 17 00:00:00 2001 From: mark Date: Tue, 8 Oct 2024 15:32:34 -0700 Subject: [PATCH] Converts checkpointer and state builder to configurable. --- axlearn/common/checkpointer.py | 28 ++++---- axlearn/common/checkpointer_orbax.py | 5 +- axlearn/common/checkpointer_orbax_test.py | 6 +- axlearn/common/checkpointer_test.py | 44 ++++++------ axlearn/common/evaler.py | 15 +++-- axlearn/common/inference.py | 4 +- axlearn/common/inference_output.py | 44 ++++-------- axlearn/common/inference_pipeline.py | 11 +-- axlearn/common/inference_test.py | 12 ++-- axlearn/common/input_tf_data_test.py | 7 +- axlearn/common/optimizers_test.py | 5 +- axlearn/common/state_builder.py | 45 ++++++------- axlearn/common/state_builder_test.py | 82 +++++++++-------------- axlearn/common/summary_test.py | 8 +-- axlearn/common/summary_writer.py | 36 +++++----- axlearn/common/summary_writer_test.py | 27 ++++---- axlearn/common/trainer.py | 13 ++-- axlearn/common/trainer_test.py | 18 +++-- 18 files changed, 184 insertions(+), 226 deletions(-) diff --git a/axlearn/common/checkpointer.py b/axlearn/common/checkpointer.py index 778b11f5..fa121829 100644 --- a/axlearn/common/checkpointer.py +++ b/axlearn/common/checkpointer.py @@ -35,14 +35,10 @@ Required, config_class, config_for_function, + maybe_instantiate, ) from axlearn.common.metrics import WeightedScalar -from axlearn.common.module import ( - InvocationContext, - Module, - clone_context_stack, - install_context_stack, -) +from axlearn.common.module import InvocationContext, clone_context_stack, install_context_stack from axlearn.common.summary_writer import CheckpointerAction, SummaryWriter from axlearn.common.utils import ( Nested, @@ -630,7 +626,7 @@ def fn(*, step: int, evaler_summaries: dict[str, Any]) -> bool: return fn -class BaseCheckpointer(Module): +class BaseCheckpointer(Configurable): """A base checkpointer interface. Checkpointers are required to implement `save`, `restore`, `stop`, and `checkpoint_paths`. @@ -641,7 +637,7 @@ class BaseCheckpointer(Module): """ @config_class - class Config(Module.Config): + class Config(Configurable.Config): """Configures BaseCheckpointer. Attributes: @@ -680,8 +676,8 @@ def latest_checkpoint_path(cls, base_dir: str) -> str: # Note: checkpoint_paths should already filter incomplete checkpoints. return sorted(cls.checkpoint_paths(base_dir)).pop() - def __init__(self, cfg: Module.Config, *, parent: Optional[Module]): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: Configurable.Config): + super().__init__(cfg) self._within_context = False def __enter__(self): @@ -824,8 +820,8 @@ def cleanup_checkpoint(cls, ckpt_dir: str, *, sync: bool = True): # Wait for cleanup to complete. 
multihost_utils.sync_global_devices(f"{ckpt_dir}_cleanup") - def __init__(self, cfg: Config, *, parent: Optional[Module]): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: Config): + super().__init__(cfg) cfg: Checkpointer.Config = self.config self._storage: StateStorage = cfg.storage.instantiate() @@ -834,7 +830,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]): self._save_policy: CheckpointPolicy = cfg.save_policy.instantiate() if cfg.summary_writer is not None: cfg.summary_writer.dir = cfg.summary_writer.dir or cfg.dir - self._add_child("summary_writer", cfg.summary_writer) + self.summary_writer: Optional[SummaryWriter] = maybe_instantiate(cfg.summary_writer) def __enter__(self): super().__enter__() @@ -845,7 +841,7 @@ def _start_gc_thread(self): if self._gc_thread is None and jax.process_index() == 0: self._gc_stopping = threading.Event() self._gc_thread = threading.Thread( - name=f"{self.path()}.gc_loop", + name=f"{self.__class__.__name__}.gc_loop", target=self._gc_loop, kwargs=dict(context_stack=clone_context_stack()), ) @@ -894,7 +890,7 @@ def save( self._storage.save_to_dir( step=step, state=state, ckpt_dir=ckpt_dir, on_commit_callback=write_index_file ) - if "summary_writer" in self.children: + if self.summary_writer is not None: self.summary_writer.log_checkpoint( step=step, state=state, @@ -1009,7 +1005,7 @@ def validate_and_restore(*, step: int, ckpt_dir: str): step=step, state=state, ckpt_dir=ckpt_dir ) logging.info("Restored state from ckpt at step %s", step) - if "summary_writer" in self.children: + if self.summary_writer is not None: self.summary_writer.log_checkpoint( step=step, state=state, diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py index b6e52dc1..526ad8ea 100644 --- a/axlearn/common/checkpointer_orbax.py +++ b/axlearn/common/checkpointer_orbax.py @@ -28,7 +28,6 @@ restore_tf_savables, ) from axlearn.common.config import config_class -from axlearn.common.module import Module from axlearn.common.utils import Nested, Tensor, TensorSpec @@ -123,8 +122,8 @@ def checkpoint_paths(cls, base_dir: str) -> List[str]: """See `BaseCheckpointer.checkpointer_paths`.""" return [str(path) for path in ocp.utils.checkpoint_steps_paths(base_dir)] - def __init__(self, cfg: Config, *, parent: Optional[Module]): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: Config): + super().__init__(cfg) cfg: OrbaxCheckpointer.Config = self.config save_policy = cfg.save_policy.instantiate() diff --git a/axlearn/common/checkpointer_orbax_test.py b/axlearn/common/checkpointer_orbax_test.py index edc8e1fd..c8ad2883 100644 --- a/axlearn/common/checkpointer_orbax_test.py +++ b/axlearn/common/checkpointer_orbax_test.py @@ -31,11 +31,7 @@ def test_index(self): if not test_utils.is_supported_mesh_shape(mesh_shape): return with _mesh(mesh_shape), tempfile.TemporaryDirectory() as temp_dir: - ckpt = ( - OrbaxCheckpointer.default_config() - .set(name="test", dir=temp_dir) - .instantiate(parent=None) - ) + ckpt = OrbaxCheckpointer.default_config().set(dir=temp_dir).instantiate() step = 123 state = dict(x=jnp.ones([3, 2])) ckpt.save(step=step, state=state) diff --git a/axlearn/common/checkpointer_test.py b/axlearn/common/checkpointer_test.py index f398c070..52425094 100644 --- a/axlearn/common/checkpointer_test.py +++ b/axlearn/common/checkpointer_test.py @@ -63,7 +63,6 @@ def _checkpointer_config( ) -> BaseCheckpointer.Config: # TODO(markblee): Use context manager instead of mkdtemp. 
return checkpointer_cls.default_config().set( - name="test", dir=tempfile.mkdtemp(), keep_last_n=1, ) @@ -78,7 +77,7 @@ def test_save_and_restore(self, checkpointer_cls: Type[BaseCheckpointer]): with _mesh(mesh_shape): cfg = _checkpointer_config(checkpointer_cls) cfg.save_policy.min_step = 0 - ckpt: BaseCheckpointer = cfg.instantiate(parent=None) + ckpt: BaseCheckpointer = cfg.instantiate() state0 = dict(x=jnp.zeros([], dtype=jnp.int32), y=jnp.ones([2], dtype=jnp.float32)) state1 = dict(x=jnp.ones([], dtype=jnp.int32), y=jnp.ones([2], dtype=jnp.float32) + 1) @@ -155,7 +154,7 @@ def test_save_and_restore_mesh(self, checkpointer_cls: Type[BaseCheckpointer]): return cfg = _checkpointer_config(checkpointer_cls) - ckpt: BaseCheckpointer = cfg.instantiate(parent=None) + ckpt: BaseCheckpointer = cfg.instantiate() state = dict( x=jax.random.uniform(jax.random.PRNGKey(123), shape=[8, 4], dtype=jnp.float32), ) @@ -203,7 +202,7 @@ def test_save_and_restore_latest_valid(self, checkpointer_cls: Type[BaseCheckpoi return with _mesh(mesh_shape): cfg = _checkpointer_config(checkpointer_cls) - ckpt: BaseCheckpointer = cfg.instantiate(parent=None) + ckpt: BaseCheckpointer = cfg.instantiate() state0 = dict(x=jnp.zeros([], dtype=jnp.int32), y=jnp.ones([2], dtype=jnp.float32)) # Restoring from an empty dir returns the input state if step=None. @@ -251,7 +250,7 @@ def test_gda(self, checkpointer_cls, mesh_shape): return with _mesh(mesh_shape): cfg = _checkpointer_config(checkpointer_cls) - ckpt: Checkpointer = cfg.instantiate(parent=None) + ckpt: Checkpointer = cfg.instantiate() state = dict(x=jnp.arange(16).reshape((4, 4))) ckpt.save(step=10, state=state) ckpt.wait_until_finished() @@ -282,7 +281,7 @@ def test_custom_dict(self, checkpointer_cls, custom_dict_type): return with _mesh(mesh_shape): cfg = _checkpointer_config(checkpointer_cls) - ckpt: Checkpointer = cfg.instantiate(parent=None) + ckpt: Checkpointer = cfg.instantiate() state0 = custom_dict_type( x=jnp.zeros([], dtype=jnp.int32), y=jnp.ones([2], dtype=jnp.float32) ) @@ -307,7 +306,7 @@ def test_input_iterator(self, checkpointer_cls): return with _mesh(mesh_shape): cfg = _checkpointer_config(checkpointer_cls) - ckpt: Checkpointer = cfg.instantiate(parent=None) + ckpt: Checkpointer = cfg.instantiate() input_iter = iter(tf.data.Dataset.from_tensor_slices([1, 2, 3])) # Move the input_iter. self.assertEqual(next(input_iter), 1) @@ -409,11 +408,14 @@ def patch_tf_io_behavior(*args): return [x + "/" for x in out if not x.endswith("/")] # pylint: disable=line-too-long - with _mesh(mesh_shape), mock.patch( - "tensorflow.io.gfile.listdir", patch_tf_io_behavior - ) if listdir_add_trailing_slash else nullcontext(), tempfile.TemporaryDirectory() as temp_dir: + with ( + _mesh(mesh_shape), + mock.patch("tensorflow.io.gfile.listdir", patch_tf_io_behavior) + if listdir_add_trailing_slash + else nullcontext(), + tempfile.TemporaryDirectory() as temp_dir, + ): cfg = Checkpointer.default_config().set( - name="test", dir=temp_dir, keep_last_n=3, keep_every_n_steps=2, @@ -422,10 +424,10 @@ def patch_tf_io_behavior(*args): cfg.save_policy.min_step = 0 # Running gc for non-existent dir shouldn't fail. 
- ckpt_fake = cfg.clone(dir=os.path.join(temp_dir, "fake_dir")).instantiate(parent=None) + ckpt_fake = cfg.clone(dir=os.path.join(temp_dir, "fake_dir")).instantiate() ckpt_fake._run_garbage_collection() - ckpt: Checkpointer = cfg.instantiate(parent=None) + ckpt: Checkpointer = cfg.instantiate() state = dict(x=jnp.zeros([], dtype=jnp.int32)) for step in range(10): @@ -532,7 +534,7 @@ def test_stop(self): if not test_utils.is_supported_mesh_shape(mesh_shape): return cfg = _checkpointer_config() - ckpt: Checkpointer = cfg.instantiate(parent=None) + ckpt: Checkpointer = cfg.instantiate() # GC thread is not started until the start_gc_thread() call. self.assertIsNone(ckpt._gc_thread) @@ -550,7 +552,7 @@ def test_stop(self): @parameterized.parameters([Checkpointer, OrbaxCheckpointer]) def test_context(self, checkpointer_cls): - ckpt = _checkpointer_config(checkpointer_cls).instantiate(parent=None) + ckpt = _checkpointer_config(checkpointer_cls).instantiate() if checkpointer_cls is Checkpointer: with ckpt: @@ -568,7 +570,7 @@ def test_context(self, checkpointer_cls): def test_stop_on_exception(self): # Ensure that checkpointer gc thread terminates if there's an exception. - ckpt = _checkpointer_config().instantiate(parent=None) + ckpt = _checkpointer_config().instantiate() def run(): ckpt._start_gc_thread() @@ -598,7 +600,7 @@ def test_summary_writer_checkpoint(self): with _mesh(mesh_shape): cfg = _checkpointer_config() cfg.summary_writer = SummaryWriter.default_config() - ckpt: Checkpointer = cfg.instantiate(parent=None) + ckpt: Checkpointer = cfg.instantiate() self.assertIsNotNone(ckpt.summary_writer) ckpt.summary_writer.log_checkpoint = mock.Mock() @@ -634,7 +636,7 @@ def _create_metric(value): metric=EvalMetric(evaler_name="evaler", metric_name="metric"), mode=mode ) ) - ckpt: Checkpointer = cfg.instantiate(parent=None) + ckpt: Checkpointer = cfg.instantiate() state0 = dict(x=jnp.zeros([], dtype=jnp.int32)) state2 = dict(x=jnp.ones([], dtype=jnp.int32) * 2) state4 = dict(x=jnp.ones([], dtype=jnp.int32) * 4) @@ -677,7 +679,7 @@ def test_best_metric_policy_value_error(self, checkpointer_cls): metric=EvalMetric(evaler_name="evaler", metric_name="metric"), mode="max" ) ) - ckpt: Checkpointer = cfg.instantiate(parent=None) + ckpt: Checkpointer = cfg.instantiate() state0 = dict(x=jnp.zeros([], dtype=jnp.int32)) with pytest.raises(ValueError, match=re.escape("evaler_summaries is empty")): @@ -766,7 +768,7 @@ def test_read_state_spec(self, checkpointer_cls: Type[BaseCheckpointer]): with _mesh(mesh_shape): cfg = _checkpointer_config(checkpointer_cls) cfg.save_policy.min_step = 0 - ckpt: BaseCheckpointer = cfg.instantiate(parent=None) + ckpt: BaseCheckpointer = cfg.instantiate() state0 = dict( **{ f"v_{str(dtype.dtype)}": jnp.zeros([], dtype=dtype) @@ -826,7 +828,7 @@ def tree_unflatten(cls, keys, values): with unittest.mock.patch.dict(globals(), {"SWITCHABLE_VDICT_IMPL": OldVDict}): cfg = _checkpointer_config() cfg.save_policy.min_step = 0 - ckpt: Checkpointer = cfg.instantiate(parent=None) + ckpt: Checkpointer = cfg.instantiate() # VDict with out of order keys. state0 = dict(a=3, b=SwitchableVDict(d=6, b=5)) state0 = jax.tree.map(jnp.asarray, state0) diff --git a/axlearn/common/evaler.py b/axlearn/common/evaler.py index 7ef8fa41..a368e428 100644 --- a/axlearn/common/evaler.py +++ b/axlearn/common/evaler.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. 
"""Evaler and base metric calculators.""" + import functools import graphlib import os.path @@ -15,7 +16,7 @@ from jax import numpy as jnp from jax.experimental.pjit import pjit -from axlearn.common import struct, summary_writer, utils +from axlearn.common import struct, utils from axlearn.common.base_model import BaseModel from axlearn.common.config import ( REQUIRED, @@ -23,11 +24,13 @@ Required, config_class, config_for_function, + maybe_instantiate, ) from axlearn.common.inference_output import BaseOutputWriter from axlearn.common.metrics import MetricAccumulator, WeightedScalar from axlearn.common.module import Module, OutputCollection from axlearn.common.module import functional as F +from axlearn.common.summary_writer import BaseWriter, SummaryWriter from axlearn.common.utils import ( NestedPartitionSpec, NestedTensor, @@ -558,7 +561,7 @@ class Config(Module.Config): # The input source. input: Required[InstantiableConfig] = REQUIRED # A summary writer to log tagged summary values. - summary_writer: InstantiableConfig = summary_writer.SummaryWriter.default_config() + summary_writer: BaseWriter.Config = SummaryWriter.default_config() # Run this evaler according to this policy. eval_policy: InstantiableConfig = config_for_function(every_n_steps_policy) # Which evaluation iters to trace with the profiler each time the evaler is run. @@ -597,10 +600,8 @@ def __init__( model=model, model_param_partition_specs=model_param_partition_specs, ) - self._add_child("summary_writer", cfg.summary_writer) - if cfg.output_writer is not None: - self._add_child("output_writer", cfg.output_writer) - + self.output_writer: Optional[BaseOutputWriter] = maybe_instantiate(cfg.output_writer) + self.summary_writer: BaseWriter = cfg.summary_writer.instantiate() self._trace_steps = set() self._eval_policy: EvalPolicy = cfg.eval_policy.instantiate() @@ -696,7 +697,7 @@ def eval_step( ) metric_calculator_state = forward_outputs["state"] all_metric_calculator_outputs.append(forward_outputs["output"]) - if "output_writer" in self.children: + if self.output_writer is not None: self.output_writer.write( input_batch=global_input_batch, output_batch=forward_outputs["output"] ) diff --git a/axlearn/common/inference.py b/axlearn/common/inference.py index 0b3ad390..2982457a 100644 --- a/axlearn/common/inference.py +++ b/axlearn/common/inference.py @@ -225,9 +225,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]): lambda spec: spec.mesh_axes, self._inference_runner_state_specs ) logging.info("Building ckpt state from %s", cfg.init_state_builder.klass.__name__) - builder = cfg.init_state_builder.set( - name="init_state_builder", - ).instantiate(parent=None) + builder: Builder = cfg.init_state_builder.instantiate() # Check that builder should expect tensor specs. 
if builder.input_state_type() != Builder.StateType.TENSOR_SPECS: diff --git a/axlearn/common/inference_output.py b/axlearn/common/inference_output.py index 1821662d..3fdddc70 100644 --- a/axlearn/common/inference_output.py +++ b/axlearn/common/inference_output.py @@ -4,7 +4,7 @@ import json import os.path -from typing import Optional, Union +from typing import Union import jax import numpy as np @@ -12,8 +12,7 @@ from jax import numpy as jnp from axlearn.common import file_system as fs -from axlearn.common.config import REQUIRED, Required, config_class -from axlearn.common.module import Module +from axlearn.common.config import REQUIRED, Configurable, Required, config_class from axlearn.common.utils import ( DataPartitionType, NestedTensor, @@ -23,11 +22,11 @@ ) -class BaseOutputWriter(Module): +class BaseOutputWriter(Configurable): """Base class for OutputWriter, which writes records for inference outputs.""" @config_class - class Config(Module.Config): + class Config(Configurable.Config): # How input and output batches are partitioned. batch_partition_spec: Required[DataPartitionType] = REQUIRED @@ -45,7 +44,7 @@ def flush(self): raise NotImplementedError(type(self)) -class BaseRecordSink(Module): +class BaseRecordSink(Configurable): def write(self, record: NestedTensor): """Writes `record` to the sink.""" raise NotImplementedError(type(self)) @@ -91,7 +90,7 @@ class TfExampleRecordSink(BaseRecordSink): """A sink that writes each example as a record to a TF record file.""" @config_class - class Config(Module.Config): + class Config(Configurable.Config): # The path should commonly contain substitution patterns for: # # - `data_dir`: The data directory name from `get_data_dir()` @@ -101,13 +100,8 @@ class Config(Module.Config): # E.g., output_path = "{data_dir}/out-records-{process_index:05d}-of-{process_count:05d}". output_path: Required[str] = REQUIRED - def __init__( - self, - cfg: Config, - *, - parent: Optional[Module], - ): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: Config): + super().__init__(cfg) cfg = self.config output_path = cfg.output_path.format( data_dir=get_data_dir(), @@ -132,7 +126,7 @@ class JsonlExampleRecordSink(BaseRecordSink): """A sink that writes each example as a record to a JSON Lines file.""" @config_class - class Config(Module.Config): + class Config(Configurable.Config): # The path should commonly contain substitution patterns for: # # - `data_dir`: The data directory name from `get_data_dir()` @@ -143,13 +137,8 @@ class Config(Module.Config): # process_count:05d}.jsonl". output_path: Required[str] = REQUIRED - def __init__( - self, - cfg: Config, - *, - parent: Optional[Module], - ): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: Config): + super().__init__(cfg) cfg = self.config output_path = cfg.output_path.format( data_dir=get_data_dir(), @@ -181,15 +170,10 @@ class OutputRecordWriter(BaseOutputWriter): class Config(BaseOutputWriter.Config): sink: BaseRecordSink.Config = TfExampleRecordSink.default_config() - def __init__( - self, - cfg: Config, - *, - parent: Optional[Module], - ): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: Config): + super().__init__(cfg) cfg = self.config - self._add_child("sink", cfg.sink) + self.sink = cfg.sink.instantiate() def write(self, *, input_batch: NestedTensor, output_batch: NestedTensor): """Writes records extracted from the given input/output batch pair. 
diff --git a/axlearn/common/inference_pipeline.py b/axlearn/common/inference_pipeline.py index dba02972..bffd8fe3 100644 --- a/axlearn/common/inference_pipeline.py +++ b/axlearn/common/inference_pipeline.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. """An inference pipeline consists of an input, a runner, and an output writer.""" + import time from typing import Optional @@ -105,11 +106,11 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]): cfg = self.config self._add_child("input", cfg.input) self._add_child("runner", cfg.runner) - self._add_child( - "output_writer", - cfg.output_writer.set(batch_partition_spec=cfg.runner.input_batch_partition_spec), - ) - self._add_child("summary_writer", cfg.summary_writer) + + self.output_writer = cfg.output_writer.set( + batch_partition_spec=cfg.runner.input_batch_partition_spec + ).instantiate() + self.summary_writer = cfg.summary_writer.instantiate() def run(self, **kwargs): cfg = self.config diff --git a/axlearn/common/inference_test.py b/axlearn/common/inference_test.py index d6d946ed..7d561759 100644 --- a/axlearn/common/inference_test.py +++ b/axlearn/common/inference_test.py @@ -226,7 +226,6 @@ def _runner_config( ) if use_ema: inference_runner_cfg.init_state_builder = RestoreAndConvertBuilder.default_config().set( - name="builder", builder=TensorStoreStateStorageBuilder.default_config().set( dir=ckpt_dir, validation=CheckpointValidationType.CONTAINS_STATE_UP_TO_DTYPE ), @@ -779,10 +778,13 @@ def test_pipeline_summary_writer( mock_summary_writer = mock.Mock(return_value=None) - with mock.patch( - "axlearn.common.summary_writer.SummaryWriter.Config.instantiate", - mock.MagicMock(return_value=mock_summary_writer), - ), tempfile.TemporaryDirectory() as local_tmp_dir: + with ( + mock.patch( + "axlearn.common.summary_writer.SummaryWriter.Config.instantiate", + mock.MagicMock(return_value=mock_summary_writer), + ), + tempfile.TemporaryDirectory() as local_tmp_dir, + ): root_dir = local_tmp_dir if local_run else "gs://axlearn-public/testdata/inference_test" with set_data_dir(root_dir): prng_key = jax.random.PRNGKey(11) diff --git a/axlearn/common/input_tf_data_test.py b/axlearn/common/input_tf_data_test.py index 84c2c83b..28ef8dd2 100644 --- a/axlearn/common/input_tf_data_test.py +++ b/axlearn/common/input_tf_data_test.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. """Tests tf.data inputs.""" + # pylint: disable=no-self-use,too-many-lines import os import tempfile @@ -561,11 +562,7 @@ def _check_iterator_saveable(self, iterator: Iterable): with tempfile.TemporaryDirectory() as td: save_dir = os.path.join(td, "ckpt") step = 100 - ckptr = ( - Checkpointer.default_config() - .set(name="ckptr", dir=save_dir) - .instantiate(parent=None) - ) + ckptr = Checkpointer.default_config().set(dir=save_dir).instantiate() with Mesh(jax.devices(), "data"): ckptr.save(step=step, state={"iterator": iterator}, evaler_summaries=None) ckptr.wait_until_finished() diff --git a/axlearn/common/optimizers_test.py b/axlearn/common/optimizers_test.py index 5ed40098..b3db70db 100644 --- a/axlearn/common/optimizers_test.py +++ b/axlearn/common/optimizers_test.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. 
"""Tests optimization modules.""" + # pylint: disable=no-self-use,too-many-lines import itertools import tempfile @@ -101,7 +102,7 @@ def _mesh(mesh_shape: Sequence[int]): def _checkpointer_config(): - return Checkpointer.default_config().set(name="test", dir=tempfile.mkdtemp()) + return Checkpointer.default_config().set(dir=tempfile.mkdtemp()) class OldSkipClipState(NamedTuple): @@ -861,7 +862,7 @@ def test_gradient_skipping_backward_compatibility(self): with _mesh(mesh_shape): cfg = _checkpointer_config() cfg.save_policy.min_step = 0 - ckpt: Checkpointer = cfg.instantiate(parent=None) + ckpt: Checkpointer = cfg.instantiate() # Save the older version of state. ckpt.save(step=0, state=prev_state) ckpt.wait_until_finished() diff --git a/axlearn/common/state_builder.py b/axlearn/common/state_builder.py index 3d374bcd..e2ebe003 100644 --- a/axlearn/common/state_builder.py +++ b/axlearn/common/state_builder.py @@ -32,13 +32,13 @@ from axlearn.common.config import ( REQUIRED, ConfigOr, + Configurable, InstantiableConfig, Required, config_class, config_for_function, maybe_instantiate, ) -from axlearn.common.module import Module from axlearn.common.optimizer_base import OptStateSpec from axlearn.common.optimizers import ParamEmaState from axlearn.common.utils import ( @@ -56,7 +56,7 @@ from axlearn.experiments.trainer_config_utils import TrainerConfigFn -class Builder(Module): +class Builder(Configurable): """An abstract class for building trainer states.""" class StateType(enum.Enum): @@ -64,7 +64,7 @@ class StateType(enum.Enum): TENSOR_SPECS = "tensor_specs" @config_class - class Config(Module.Config): + class Config(Configurable.Config): pass @dataclass @@ -110,13 +110,13 @@ class Config(Builder.Config): # If validation is None, no structure check will be enforced. validation: Optional[CheckpointValidationType] = CheckpointValidationType.EXACT - def __init__(self, cfg: Config, *, parent: Optional[Module]): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: Config): + super().__init__(cfg) cfg = self.config self._builders = [] - for i, builder_cfg in enumerate(cfg.builders): - self._builders.append(self._add_child(f"builder{i:03d}", builder_cfg)) + for builder_cfg in cfg.builders: + self._builders.append(builder_cfg.instantiate()) def input_state_type(self) -> Builder.StateType: return self._builders[0].input_state_type() @@ -140,7 +140,7 @@ def __call__(self, state: Builder.State) -> Builder.State: return state -class Converter(Module): +class Converter(Configurable): """Converts builder state from one version to another. 
This can be used to migrate legacy trainer state structures to updated structures in the case @@ -148,7 +148,7 @@ class Converter(Module): """ @config_class - class Config(Module.Config): + class Config(Configurable.Config): pass def target_state_type(self) -> Builder.StateType: @@ -192,12 +192,12 @@ class ChainConverter(Converter): class Config(Converter.Config): converters: Required[Sequence[Converter.Config]] = REQUIRED - def __init__(self, cfg: Config, *, parent: Optional[Module]): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: Config): + super().__init__(cfg) cfg = self.config self._converters = [] - for i, converter_cfg in enumerate(cfg.converters): - self._converters.append(self._add_child(f"converter{i:03d}", converter_cfg)) + for converter_cfg in cfg.converters: + self._converters.append(converter_cfg.instantiate()) def target_state_type(self) -> Builder.StateType: for converter in self._converters: @@ -265,8 +265,8 @@ class MergeStateConverter(Converter): class Config(Converter.Config): selection_regexes: list[tuple[str, Union[str, MergeStateSelection]]] = [] - def __init__(self, cfg: "MergeStateConverter.Config", *, parent: Optional[Module] = None): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: "MergeStateConverter.Config"): + super().__init__(cfg) self.selection_regexes = [ (re.compile(r), MergeStateSelection(s)) for r, s in cfg.selection_regexes @@ -322,11 +322,11 @@ def spec_to_config(cls, spec: str) -> Config: cfg.builder = cfg.builder.klass.spec_to_config(spec) return cfg - def __init__(self, cfg: Config, *, parent: Optional[Module]): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: Config): + super().__init__(cfg) cfg = self.config - self._add_child("builder", cfg.builder) - self._add_child("converter", cfg.converter) + self.builder = cfg.builder.instantiate() + self.converter = cfg.converter.instantiate() def input_state_type(self) -> Builder.StateType: return self.converter.target_state_type() @@ -669,8 +669,8 @@ class Config(Builder.Config): # axlearn model params. 
converter: InstantiableConfig = config_for_function(torch_to_axlearn_converter) - def __init__(self, cfg: Config, *, parent: Optional[Module]): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: Config): + super().__init__(cfg) cfg = self.config self._converter = cfg.converter.instantiate() @@ -1200,5 +1200,4 @@ def get_builder(spec: Union[str, Builder]) -> Builder: if isinstance(spec, Builder): return spec builder_cfg = get_builder_config(spec, builders=_BUILDERS) - builder = builder_cfg.set(name="builder").instantiate(parent=None) - return builder + return builder_cfg.instantiate() diff --git a/axlearn/common/state_builder_test.py b/axlearn/common/state_builder_test.py index fb6e39fb..fdbc7a93 100644 --- a/axlearn/common/state_builder_test.py +++ b/axlearn/common/state_builder_test.py @@ -87,14 +87,13 @@ def __call__(self, state: Builder.State) -> Builder.State: ) cfg = RestoreAndConvertBuilder.default_config().set( - name="test", builder=DummyBuilder.default_config().set(), # Tests both string and enum arguments converter=MergeStateConverter.default_config().set( selection_regexes=[(r"b/a", MergeStateSelection.SOURCE), (r".*", "TARGET")] ), ) - builder = cfg.instantiate(parent=None) + builder = cfg.instantiate() state = Builder.State( step=0, trainer_state={"a": jnp.array(0), "b": {"a": jnp.array(0)}}, built_keys=set() @@ -117,11 +116,10 @@ def __call__(self, state: Builder.State) -> Builder.State: ) cfg = RestoreAndConvertBuilder.default_config().set( - name="test", builder=DummyBuilder.default_config().set(), converter=MergeStateConverter.default_config(), ) - builder = cfg.instantiate(parent=None) + builder = cfg.instantiate() state = Builder.State( step=0, trainer_state={"a": jnp.array(0), "b": {"a": jnp.array(0)}}, built_keys=set() @@ -166,7 +164,6 @@ def __call__(self, state: Builder.State) -> Builder.State: num_converters = 5 cfg = RestoreAndConvertBuilder.default_config().set( - name="test", converter=ChainConverter.default_config().set( converters=[ DummyConverter.default_config().set(id=i) for i in range(num_converters) @@ -174,7 +171,7 @@ def __call__(self, state: Builder.State) -> Builder.State: ), builder=DummyBuilder.default_config().set(expected_step=num_converters), ) - builder = cfg.instantiate(parent=None) + builder = cfg.instantiate() state = Builder.State(step=0, trainer_state={}, built_keys=set()) state = builder(state) @@ -229,7 +226,7 @@ def __call__(self, state: Builder.State) -> Builder.State: model = ( Linear.default_config() - .set(name="model", input_dim=1, output_dim=1) + .set(name="test", input_dim=1, output_dim=1) .instantiate(parent=None) ) prng_key = jax.random.PRNGKey(0) @@ -241,10 +238,9 @@ def __call__(self, state: Builder.State) -> Builder.State: builder = ( ChainBuilder.default_config() .set( - name="builder", builders=[BiasPlusOneBuilder.default_config(), BiasPlusOneBuilder.default_config()], ) - .instantiate(parent=None) + .instantiate() ) new_trainer_state = builder( Builder.State(step=0, trainer_state=init_trainer_state, built_keys=set()) @@ -265,7 +261,7 @@ def __call__(self, state: Builder.State) -> Builder.State: model = ( Linear.default_config() - .set(name="model", input_dim=1, output_dim=1) + .set(name="test", input_dim=1, output_dim=1) .instantiate(parent=None) ) prng_key = jax.random.PRNGKey(0) @@ -276,8 +272,8 @@ def __call__(self, state: Builder.State) -> Builder.State: ) builder = ( ChainBuilder.default_config() - .set(name="builder", builders=[RemoveBiasBuilder.default_config()]) - .instantiate(parent=None) + 
.set(builders=[RemoveBiasBuilder.default_config()]) + .instantiate() ) with self.assertRaises(ValueError): builder(Builder.State(step=0, trainer_state=init_trainer_state, built_keys=set())) @@ -318,11 +314,10 @@ def test_truncation(self, strategy): _, target_state = self._mock_bert_trainer_config_and_state(max_len=64) cfg = PosEmbeddingConverter.default_config().set( - name="pos_emb_tester", source_trainer_config=source_trainer_config, strategy=strategy, ) - converter: PosEmbeddingConverter = cfg.instantiate(parent=None) + converter: PosEmbeddingConverter = cfg.instantiate() converted_state = converter.source_to_target(source_state, target_state) self.assertEqual( @@ -351,7 +346,6 @@ def test_truncation_target_longer_than_source(self, strategy): source_trainer_config, source_state = self._mock_bert_trainer_config_and_state(max_len=32) _, target_state = self._mock_bert_trainer_config_and_state(max_len=64) cfg = PosEmbeddingConverter.default_config().set( - name="pos_emb_tester", source_trainer_config=source_trainer_config, strategy=strategy, ) @@ -359,7 +353,7 @@ def test_truncation_target_longer_than_source(self, strategy): with self.assertRaisesWithLiteralMatch( ValueError, "Target length 64 must be <= source len 32." ): - converter: PosEmbeddingConverter = cfg.instantiate(parent=None) + converter: PosEmbeddingConverter = cfg.instantiate() converter.source_to_target(source_state, target_state) @parameterized.parameters( @@ -371,7 +365,6 @@ def test_truncation_incompatible_shape(self, strategy: str): ) _, target_state = self._mock_bert_trainer_config_and_state(max_len=64, hidden_dim=256) cfg = PosEmbeddingConverter.default_config().set( - name="pos_emb_tester", source_trainer_config=source_trainer_config, strategy=strategy, ) @@ -379,7 +372,7 @@ def test_truncation_incompatible_shape(self, strategy: str): with self.assertRaisesWithLiteralMatch( ValueError, "Incompatible shapes: source (1, 512, 128) vs. target (1, 64, 256)." 
): - converter: PosEmbeddingConverter = cfg.instantiate(parent=None) + converter: PosEmbeddingConverter = cfg.instantiate() converter.source_to_target(source_state, target_state) def test_replace_target_prefix_with_source(self): @@ -387,11 +380,10 @@ def test_replace_target_prefix_with_source(self): _, target_state = self._mock_bert_trainer_config_and_state(max_len=64) cfg = PosEmbeddingConverter.default_config().set( - name="pos_emb_tester", source_trainer_config=source_trainer_config, strategy="replace_target_prefix_with_source", ) - converter: PosEmbeddingConverter = cfg.instantiate(parent=None) + converter: PosEmbeddingConverter = cfg.instantiate() converted_state = converter.source_to_target(source_state, target_state) self.assertEqual( @@ -422,11 +414,10 @@ def test_keep_target(self): _, target_state = self._mock_bert_trainer_config_and_state(max_len=64) cfg = PosEmbeddingConverter.default_config().set( - name="pos_emb_tester", source_trainer_config=source_trainer_config, strategy="keep_target", ) - converter: PosEmbeddingConverter = cfg.instantiate(parent=None) + converter: PosEmbeddingConverter = cfg.instantiate() converted_state = converter.source_to_target(source_state, target_state) target_weight = replicate_to_local_data( @@ -442,7 +433,6 @@ def test_non_existent_strategy(self): source_trainer_config, source_state = self._mock_bert_trainer_config_and_state(max_len=512) _, target_state = self._mock_bert_trainer_config_and_state(max_len=64) cfg = PosEmbeddingConverter.default_config().set( - name="pos_emb_tester", source_trainer_config=source_trainer_config, strategy="nonexistent", ) @@ -451,7 +441,7 @@ def test_non_existent_strategy(self): NotImplementedError, "Strategy nonexistent is not implemented for PosEmbeddingConverter.", ): - converter: PosEmbeddingConverter = cfg.instantiate(parent=None) + converter: PosEmbeddingConverter = cfg.instantiate() converter.source_to_target(source_state, target_state) @@ -673,10 +663,9 @@ def _run_builder( builder = ( builder_cls.default_config() .set( - name="builder", **extra_converter_config_kwargs, ) - .instantiate(parent=None) + .instantiate() ) source_model = source_state.trainer_state.model @@ -747,7 +736,7 @@ def test_mesh_shape(self): mesh_shape=(-1, 1), ) # Ensure that we're able to instantiate mock_trainer_cfg with -1 in the mesh. 
- converter = cfg.set(name="test_converter").instantiate(parent=None) + converter = cfg.instantiate() converter.target_to_source(mock_state) @@ -826,11 +815,10 @@ def flax_state_supplier(): return source_params builder_config = FlaxPretrainedBuilder.default_config().set( - name="builder", flax_state_supplier_config=config_for_function(flax_state_supplier), target_scope=[], ) - builder = builder_config.instantiate(parent=None) + builder = builder_config.instantiate() restored_state = builder(builder_state) @@ -1001,13 +989,12 @@ def dummy_layer(): return x cfg = HuggingFacePreTrainedBuilder.default_config().set( - name="test", hf_layer_config=config_for_function(dummy_layer), ) if dst_layer is not None: cfg.converter.dst_layer = dst_layer - builder = cfg.instantiate(parent=None) + builder = cfg.instantiate() init_state = TrainerState( model=dict(weight=jnp.zeros([5, 2]), bias=jnp.zeros([2])), prng_key=None, learner=None ) @@ -1055,8 +1042,8 @@ def test_scope_none(self): _, target_state = _create_dummy_state(jax.random.PRNGKey(1)) converter: ModelStateScopeConverter = ( ModelStateScopeConverter.default_config() - .set(name="test", source_trainer_config=source_cfg) - .instantiate(parent=None) + .set(source_trainer_config=source_cfg) + .instantiate() ) # The pruned state has no learner entries. pruned_source_state, _ = converter.target_to_source(target_state) @@ -1093,8 +1080,8 @@ def test_target_scope_only(self): _, target_state = _create_dummy_state(jax.random.PRNGKey(1), target_model_cfg) converter = ( ModelStateScopeConverter.default_config() - .set(name="test", source_trainer_config=source_cfg, scope="nested") - .instantiate(parent=None) + .set(source_trainer_config=source_cfg, scope="nested") + .instantiate() ) converted_state = converter.source_to_target(source_state, target_state) self.assertNestedAllClose( @@ -1109,8 +1096,8 @@ def test_cross_scopes(self): ) converter = ( ModelStateScopeConverter.default_config() - .set(name="test", source_trainer_config=source_cfg, scope={"linear2": "linear"}) - .instantiate(parent=None) + .set(source_trainer_config=source_cfg, scope={"linear2": "linear"}) + .instantiate() ) converted_state = converter.source_to_target(source_state, target_state) self.assertNestedAllClose( @@ -1126,11 +1113,10 @@ def test_cross_scopes_many(self): converter = ( ModelStateScopeConverter.default_config() .set( - name="test", source_trainer_config=source_cfg, scope={"linear2/bias": "linear/bias", "linear2/weight": "linear/weight"}, ) - .instantiate(parent=None) + .instantiate() ) converted_state = converter.source_to_target(source_state, target_state) self.assertNestedAllClose( @@ -1146,11 +1132,10 @@ def test_only_linear_weight(self): converter = ( ModelStateScopeConverter.default_config() .set( - name="test", source_trainer_config=source_cfg, scope={"linear2/weight": "linear/weight"}, ) - .instantiate(parent=None) + .instantiate() ) # The pruned state has only 'linear/weight' under 'model'. 
pruned_source_state, _ = converter.target_to_source(target_state) @@ -1185,11 +1170,10 @@ def _create_fake_state_and_convert(self, scope_mapping: Dict[str, str]): converter = ( ModelStateScopeConverter.default_config() .set( - name="test", source_trainer_config=source_cfg, scope=scope_mapping, ) - .instantiate(parent=None) + .instantiate() ) converted_state = converter.source_to_target(source_state, target_state) return source_state, converted_state @@ -1308,8 +1292,8 @@ def test_source_data_dir(self, source_data_dir): _, target_state = _create_dummy_state(jax.random.PRNGKey(1)) converter: ModelStateScopeConverter = ( ModelStateScopeConverter.default_config() - .set(name="test", source_trainer_config=source_cfg, source_data_dir=source_data_dir) - .instantiate(parent=None) + .set(source_trainer_config=source_cfg, source_data_dir=source_data_dir) + .instantiate() ) converted_state = converter.source_to_target(source_state, target_state) self.assertNestedAllClose( @@ -1330,13 +1314,7 @@ def test_ema_params_converter(self, target_ema): elif target_ema == "with_learner_no_ema": del target_state.trainer_state.learner["ema"] - converter = ( - EmaParamsConverter.default_config() - .set( - name="test", - ) - .instantiate(parent=None) - ) + converter = EmaParamsConverter.default_config().instantiate() convert_state, _ = converter.target_to_source(target_state) # Test that model is empty. self.assertNestedAllClose( diff --git a/axlearn/common/summary_test.py b/axlearn/common/summary_test.py index 0deb2ff4..70e6b70b 100644 --- a/axlearn/common/summary_test.py +++ b/axlearn/common/summary_test.py @@ -35,9 +35,7 @@ class SummaryTest(TestCase): def test_add_summary_image(self): tempdir = tempfile.mkdtemp() - writer: SummaryWriter = ( - SummaryWriter.default_config().set(name="test", dir=tempdir).instantiate(parent=None) - ) + writer: SummaryWriter = SummaryWriter.default_config().set(dir=tempdir).instantiate() color_image = jax.numpy.ones((2, 5, 5, 3)) grayscale_image = jax.numpy.zeros((2, 5, 5)) writer( @@ -269,8 +267,8 @@ def _test(): try: writer: WandBWriter = ( WandBWriter.default_config() - .set(name="test", exp_name="wandb-testAddSummary", dir=tempdir, mode="offline") - .instantiate(parent=None) + .set(exp_name="wandb-testAddSummary", dir=tempdir, mode="offline") + .instantiate() ) output_collection = _test() diff --git a/axlearn/common/summary_writer.py b/axlearn/common/summary_writer.py index 0165d32e..8c67acbd 100644 --- a/axlearn/common/summary_writer.py +++ b/axlearn/common/summary_writer.py @@ -17,8 +17,14 @@ from tensorflow import summary as tf_summary from axlearn.common import file_system as fs -from axlearn.common.config import REQUIRED, ConfigBase, Required, RequiredFieldValue, config_class -from axlearn.common.module import Module +from axlearn.common.config import ( + REQUIRED, + ConfigBase, + Configurable, + Required, + RequiredFieldValue, + config_class, +) from axlearn.common.summary import AudioSummary, ImageSummary, Summary from axlearn.common.utils import NestedTensor, Tensor, tree_paths @@ -57,11 +63,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]: return wrapper -class BaseWriter(Module): +class BaseWriter(Configurable): """Base summary writer.""" @config_class - class Config(Module.Config): + class Config(Configurable.Config): """Configures BaseWriter.""" dir: Required[str] = REQUIRED # The output directory. 
@@ -115,14 +121,14 @@ class CompositeWriter(BaseWriter): class Config(BaseWriter.Config): writers: Required[dict[str, BaseWriter.Config]] = REQUIRED - def __init__(self, cfg: Config, *, parent: Optional["Module"]): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: Config): + super().__init__(cfg) cfg = self.config self._writers: list[BaseWriter] = [] for writer_name, writer_cfg in cfg.writers.items(): self._writers.append( - self._add_child(writer_name, writer_cfg.set(dir=os.path.join(cfg.dir, writer_name))) + writer_cfg.set(dir=os.path.join(cfg.dir, writer_name)).instantiate() ) @property @@ -183,8 +189,8 @@ class Config(BaseWriter.Config): max_queue: Optional[int] = None flush_ms: Optional[float] = None - def __init__(self, cfg: BaseWriter.Config, *, parent: Optional[Module]): - super().__init__(cfg, parent=parent) + def __init__(self, cfg: BaseWriter.Config): + super().__init__(cfg) cfg: SummaryWriter.Config = self.config self.summary_writer: tf_summary.SummaryWriter = ( tf_summary.create_file_writer( @@ -209,7 +215,7 @@ def log_config(self, config: ConfigBase, step: int = 0): tf_summary.text(f"trainer_config/{parts[0]}", parts[1], step=step) def __call__(self, step: int, values: dict[str, Any]): - cfg = self.config + cfg: SummaryWriter.Config = self.config if step % cfg.write_every_n_steps != 0: return @@ -221,7 +227,7 @@ def write(path: str, value: jax.Array): else: raw_value = value - self.vlog(3, "SummaryWriter %s: %s=%s", self.path(), path, raw_value) + logging.debug("SummaryWriter %s: %s=%s", self.__class__.__name__, path, raw_value) if isinstance(raw_value, Tensor) and not raw_value.is_fully_replicated: logging.warning( @@ -306,7 +312,7 @@ class Config(BaseWriter.Config): convert_2d_to_image: bool = False @classmethod - def default_config(cls: Config) -> Config: + def default_config(cls) -> Config: cfg = super().default_config() cfg.exp_name = os.environ.get("WANDB_NAME") cfg.project = os.environ.get("WANDB_PROJECT") @@ -318,13 +324,13 @@ def default_config(cls: Config) -> Config: cfg.dir = os.environ.get("WANDB_DIR") return cfg - def __init__(self, cfg: SummaryWriter.Config, *, parent: Optional[Module]): + def __init__(self, cfg: SummaryWriter.Config): if wandb is None: raise ModuleNotFoundError( "To use the Weights & Biases logger, please install wandb " "with `pip install wandb`." ) - super().__init__(cfg, parent=parent) + super().__init__(cfg) if wandb.run is None: self._initialize_run() @@ -386,7 +392,7 @@ def convert(path: str, value: Any): else: raw_value = value - self.vlog(3, "WandbWriter %s: %s=%s", self.path(), path, raw_value) + logging.debug("WandbWriter %s: %s=%s", self.__class__.__name__, path, raw_value) # Ensure all arrays are cast to numpy. # Wandb will crash if jax.Array is present. diff --git a/axlearn/common/summary_writer_test.py b/axlearn/common/summary_writer_test.py index 8962545d..ee2974ad 100644 --- a/axlearn/common/summary_writer_test.py +++ b/axlearn/common/summary_writer_test.py @@ -7,6 +7,7 @@ WANDB_API_KEY="..." 
pytest summary_writer_test.py """ + import os import tempfile from unittest import mock @@ -39,8 +40,8 @@ class SummaryWriterTest(absltest.TestCase): def test_add_summary(self): with tempfile.TemporaryDirectory() as tempdir: - cfg: SummaryWriter.Config = SummaryWriter.default_config().set(name="test", dir=tempdir) - writer = cfg.instantiate(parent=None) + cfg: SummaryWriter.Config = SummaryWriter.default_config().set(dir=tempdir) + writer = cfg.instantiate() writer( step=100, values={ @@ -67,8 +68,8 @@ def test_add_summary(self): def test_log_config(self): with tempfile.TemporaryDirectory() as tempdir: - cfg: SummaryWriter.Config = SummaryWriter.default_config().set(name="test", dir=tempdir) - writer = cfg.instantiate(parent=None) + cfg: SummaryWriter.Config = SummaryWriter.default_config().set(dir=tempdir) + writer = cfg.instantiate() writer.log_config(DummyModel.default_config()) event_acc = EventAccumulator(tempdir, size_guidance={"tensors": 0}) event_acc.Reload() @@ -106,14 +107,13 @@ def test_multiple_summary_writers(self): writer = ( CompositeWriter.default_config() .set( - name="test_multi_writer", dir=tempdir, writers={ "writer1": SummaryWriter.default_config(), "writer2": SummaryWriter.default_config(), }, ) - .instantiate(parent=None) + .instantiate() ) writer( step=100, @@ -129,14 +129,13 @@ def test_multiple_summary_writers_checkpoint(self): writer = ( CompositeWriter.default_config() .set( - name="test_multi_writer", dir=tempdir, writers={ "writer1": SummaryWriter.default_config(), "writer2": SummaryWriter.default_config(), }, ) - .instantiate(parent=None) + .instantiate() ) for sub_writer in writer.writers: sub_writer.log_checkpoint = mock.Mock() @@ -173,8 +172,8 @@ def test_add_summary(self): try: writer: WandBWriter = ( WandBWriter.default_config() - .set(name="test", exp_name="wandb-testAddSummary", dir=tempdir, mode="offline") - .instantiate(parent=None) + .set(exp_name="wandb-testAddSummary", dir=tempdir, mode="offline") + .instantiate() ) for step in [10, 20, 30, 40]: self._write_per_step(writer, step) @@ -194,8 +193,8 @@ def test_resume(self): try: writer: WandBWriter = ( WandBWriter.default_config() - .set(name="test", exp_name="wandb-testResume", dir=tempdir) - .instantiate(parent=None) + .set(exp_name="wandb-testResume", dir=tempdir) + .instantiate() ) exp_id = wandb.run.id @@ -205,8 +204,8 @@ def test_resume(self): writer: WandBWriter = ( WandBWriter.default_config() - .set(name="test", exp_name="wandb-testResume", dir=tempdir) - .instantiate(parent=None) + .set(exp_name="wandb-testResume", dir=tempdir) + .instantiate() ) assert wandb.run.id == exp_id # Because we resume from checkpoints, we may compute metrics diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index 8f618c69..8105daf7 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -247,13 +247,16 @@ def __init__( cfg.summary_writer.dir = cfg.summary_writer.dir or os.path.join( cfg.dir, "summaries", "train_train" ) - self._add_child("summary_writer", cfg.summary_writer) self._add_child("model", cfg.model) self._add_child("learner", cfg.learner) + + # Instantiate non-Module children. 
cfg.checkpointer.dir = cfg.checkpointer.dir or os.path.join(cfg.dir, "checkpoints") - self._add_child("checkpointer", cfg.checkpointer) - if cfg.init_state_builder is not None: - self._add_child("init_state_builder", cfg.init_state_builder) + self.checkpointer = cfg.checkpointer.instantiate() + self.summary_writer = cfg.summary_writer.instantiate() + self.init_state_builder: Optional[TrainerStateBuilder] = maybe_instantiate( + cfg.init_state_builder + ) self._model_param_specs = self.model.create_parameter_specs_recursively() model_param_partition_specs = jax.tree.map( @@ -524,7 +527,7 @@ def init(self, prng_key: Tensor): Args: prng_key: The initialization key. """ - if "init_state_builder" not in self.children: + if self.init_state_builder is None: self._init_with_prebuilt_state(prng_key, prebuilt_state=None) return input_state_type = self.init_state_builder.input_state_type() diff --git a/axlearn/common/trainer_test.py b/axlearn/common/trainer_test.py index 243c1953..8013dbb4 100644 --- a/axlearn/common/trainer_test.py +++ b/axlearn/common/trainer_test.py @@ -40,7 +40,13 @@ every_n_steps_and_last_policy, every_n_steps_policy, ) -from axlearn.common.config import REQUIRED, Required, config_class, config_for_function +from axlearn.common.config import ( + REQUIRED, + Configurable, + Required, + config_class, + config_for_function, +) from axlearn.common.evaler import SpmdEvaler from axlearn.common.evaler import every_n_steps_policy as eval_every_n_steps_policy from axlearn.common.learner import UpdateType, should_update_with_optimizers @@ -278,7 +284,7 @@ class DummyStateBuilder(TrainerStateBuilder): """A dummy builder that "builds" state from fixed values.""" @config_class - class Config(Module.Config): + class Config(Configurable.Config): step: Required[int] = REQUIRED model_state: Required[Callable[[], NestedTensor]] = REQUIRED input_state_type: Required[TrainerStateBuilder.StateType] = REQUIRED @@ -387,10 +393,8 @@ def test_trainer( input=DummyInput.default_config().set(total_num_batches=2), eval_dtype=step_dtype, ) - evaler_cfg.summary_writer.vlog = 5 cfg.evalers = dict(eval_dummy=evaler_cfg) cfg.checkpointer.save_policy = config_for_function(every_n_steps_policy).set(n=5) - cfg.summary_writer.vlog = 5 cfg.max_step = 12 cfg.watchdog_timeout_seconds = 0.1 cfg.vlog = 2 @@ -473,10 +477,8 @@ def test_return_evaler_summaries(self, return_evaler_summaries): eval_dtype=step_dtype, eval_policy=config_for_function(eval_every_n_steps_policy).set(n=10), ) - evaler_cfg.summary_writer.vlog = 5 cfg.evalers = dict(eval_dummy=evaler_cfg, eval_dummy2=evaler_cfg.clone()) cfg.checkpointer.save_policy = config_for_function(every_n_steps_policy).set(n=5) - cfg.summary_writer.vlog = 5 cfg.max_step = 3 cfg.watchdog_timeout_seconds = 0.1 cfg.vlog = 2 @@ -650,10 +652,8 @@ def test_should_compute_gradients(self, update_rules): evaler_cfg = SpmdEvaler.default_config().set( input=DummyInput.default_config().set(total_num_batches=2), ) - evaler_cfg.summary_writer.vlog = 5 cfg.evalers = dict(eval_dummy=evaler_cfg) cfg.checkpointer.save_policy = config_for_function(every_n_steps_policy).set(n=5) - cfg.summary_writer.vlog = 5 cfg.max_step = 12 trainer: SpmdTrainer = cfg.instantiate(parent=None) with trainer.mesh(): @@ -878,10 +878,8 @@ def test_composite_learner(self): evaler_cfg = SpmdEvaler.default_config().set( input=DummyInput.default_config().set(total_num_batches=2), ) - evaler_cfg.summary_writer.vlog = 5 cfg.evalers = dict(eval_dummy=evaler_cfg) cfg.checkpointer.save_policy = 
config_for_function(every_n_steps_policy).set(n=5) - cfg.summary_writer.vlog = 5 cfg.max_step = 12 trainer: SpmdTrainer = cfg.instantiate(parent=None) with trainer.mesh():
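
Usage sketch (a minimal illustration, not part of the diff above): this patch turns Checkpointer, the writers, and the state builders into plain Configurable classes, so call sites instantiate their configs directly instead of passing name= and parent=None. The snippet below mirrors the updated test call sites in this patch; the temporary directory, the scalar state, and the single-axis device mesh are assumptions borrowed from those tests, not a prescribed setup.

# Illustrative only; follows the call patterns in the tests modified above.
import tempfile

import jax
import jax.numpy as jnp
from jax.sharding import Mesh

from axlearn.common.checkpointer import Checkpointer

ckpt_dir = tempfile.mkdtemp()

# Before this change (Module-based), a name and parent were required:
#   ckpt = Checkpointer.default_config().set(name="ckpt", dir=ckpt_dir).instantiate(parent=None)
# After this change (Configurable-based), the config is instantiated directly:
ckpt = Checkpointer.default_config().set(dir=ckpt_dir).instantiate()

state = dict(x=jnp.zeros([], dtype=jnp.int32))
# The tests wrap save/restore in a device mesh context; a 1-D "data" mesh is assumed here.
with Mesh(jax.devices(), "data"):
    ckpt.save(step=100, state=state)
    ckpt.wait_until_finished()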