diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index d5d64a6f..0f4fa0ae 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -7,11 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Add `RootMetadata` and `StepMetadata` classes as ways for the user to +interface with checkpoint metadata at various levels. +- Add `root_metadata_serialization`, and `step_metadata_io` modules that contain utilities +to perform de/serialization for `RootMetadata` and `StepMetadata`. + ### Changed - Create `Composite` class, which `CompositeArgs` now subclasses. - Move `type_handlers` to `_src/serialization` - Add notes to Barrier error `XlaRuntimeError(DEADLINE_EXCEEDED)` with actionable info. +- Rename `CheckpointMetadataStore` to `MetadataStore`, and change methods to +accept and return metadata as dictionaries. ## [0.8.0] - 2024-10-29 diff --git a/checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py b/checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py index df7aefdf..9db98e90 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py @@ -12,56 +12,87 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Manages metadata of checkpoints at step level (not item level).""" +"""Manages metadata of checkpoints at root and step level (not item level).""" from __future__ import annotations import concurrent.futures import dataclasses import json import threading -from typing import Any, Optional, Protocol +from typing import Any, Protocol, TypeVar from absl import logging from etils import epath +from orbax.checkpoint._src import composite +from orbax.checkpoint._src.logging import step_statistics _METADATA_FILENAME = '_CHECKPOINT_METADATA' +_ROOT_METADATA_DIRNAME = 'metadata' +ItemMetadata = composite.Composite +StepStatistics = step_statistics.SaveStepStatistics +MetadataDict = TypeVar('MetadataDict', bound=dict[str, Any]) -def metadata_file_path(path: epath.PathLike) -> epath.Path: - """Returns the path to metadata file for a given checkpoint directory.""" - return epath.Path(path) / _METADATA_FILENAME + +@dataclasses.dataclass +class Metadata: + """Metadata of a checkpoint at either step or root level.""" + ... @dataclasses.dataclass -class StepMetadata: +class StepMetadata(Metadata): """Metadata of a checkpoint at step level (not per item). - NOTE: Internal class. Please reach out to Orbax team if you want to use it in - your codebase. - Attributes: + format: The checkpoint file format. + item_handlers: Map of item name to its checkpoint handler. + item_metadata: Map of item name to its metadata. + metrics: User-provided metrics (accuracy, loss, etc.) + performance_metrics: Performance metrics (time, memory, etc.) init_timestamp_nsecs: timestamp when uncommitted checkpoint was initialized. Specified as nano seconds since epoch. default=None. commit_timestamp_nsecs: commit timestamp of a checkpoint, specified as nano seconds since epoch. default=None. + custom: User-provided custom metadata. + """ + + format: str | None = None + item_handlers: dict[str, str] = dataclasses.field(default_factory=dict) + item_metadata: ItemMetadata | None = None + metrics: dict[str, Any] = dataclasses.field(default_factory=dict) + performance_metrics: StepStatistics = dataclasses.field( + default_factory=StepStatistics + ) + init_timestamp_nsecs: int | None = None + commit_timestamp_nsecs: int | None = None + custom: dict[str, Any] = dataclasses.field(default_factory=dict) + + @classmethod + def file_path(cls, path: epath.PathLike) -> epath.Path: + """The path to step metadata file for a given checkpoint directory.""" + return epath.Path(path) / _METADATA_FILENAME + + +@dataclasses.dataclass +class RootMetadata(Metadata): + """Metadata of a checkpoint at root level (contains all steps). + + Attributes: + format: The checkpoint file format. + custom: User-provided custom metadata. """ - init_timestamp_nsecs: Optional[int] = None - commit_timestamp_nsecs: Optional[int] = None + format: str | None = None + custom: dict[str, Any] = dataclasses.field(default_factory=dict) @classmethod - def from_dict(cls, dict_data: Any) -> StepMetadata: - validated_dict = {} - if 'init_timestamp_nsecs' in dict_data: - validated_dict['init_timestamp_nsecs'] = dict_data['init_timestamp_nsecs'] - if 'commit_timestamp_nsecs' in dict_data: - validated_dict['commit_timestamp_nsecs'] = dict_data[ - 'commit_timestamp_nsecs' - ] - return StepMetadata(**validated_dict) + def file_path(cls, path: epath.PathLike) -> epath.Path: + """The path to root metadata file for a given checkpoint directory.""" + return epath.Path(path) / _ROOT_METADATA_DIRNAME / _METADATA_FILENAME -class CheckpointMetadataStore(Protocol): - """Manages storage of `CheckpointMetadata`.""" +class MetadataStore(Protocol): + """Manages storage of `Metadata`.""" def is_blocking_writer(self) -> bool: """Returns True if the store performs blocking writes, otherwise False.""" @@ -69,32 +100,32 @@ def is_blocking_writer(self) -> bool: def write( self, - checkpoint_path: epath.PathLike, - checkpoint_metadata: StepMetadata, + file_path: epath.PathLike, + metadata: MetadataDict, ) -> None: - """[Over]Writes `checkpoint_metadata` to `checkpoint_path`/*metadata_file*.""" + """[Over]Writes `metadata` to `file`.""" ... def read( - self, checkpoint_path: epath.PathLike - ) -> Optional[StepMetadata]: - """Reads `checkpoint_path`/*metadata_file* and returns `CheckpointMetadata`.""" + self, file_path: epath.PathLike + ) -> MetadataDict | None: + """Reads `file` and returns `Metadata` dict.""" ... def update( self, - checkpoint_path: epath.PathLike, + file_path: epath.PathLike, **kwargs, ) -> None: - """Safely updates CheckpointMetadata at `checkpoint_path`/*metadata_file*. + """Safely updates `Metadata` at `file`. - If no updatable CheckpointMetadata is found at - `checkpoint_path`/*metadata_file*, then it creates a new one with `kwargs` - attributes. + If no updatable `Metadata` is found at `file`, then it creates a + new one with `kwargs` attributes. Args: - checkpoint_path: path to checkpoint dir (step dir). - **kwargs: Attributes of CheckpointMetadata is kwargs format. + file_path: path to metadata file in checkpoint dir (for + `RootMetadata`) or step dir (for `StepMetadata`). + **kwargs: Attributes of `Metadata` in kwargs format. """ ... @@ -107,8 +138,8 @@ def close(self) -> None: ... -class _CheckpointMetadataStoreImpl(CheckpointMetadataStore): - """Basic internal reusable impl of `CheckpointMetadata` storage. +class _MetadataStoreImpl(MetadataStore): + """Basic internal reusable impl of `Metadata` storage. It is neither thread safe, nor does it check for read/write capabilities. """ @@ -118,42 +149,42 @@ def is_blocking_writer(self) -> bool: def write( self, - checkpoint_path: epath.PathLike, - checkpoint_metadata: StepMetadata, + file_path: epath.PathLike, + metadata: MetadataDict, ) -> None: - checkpoint_path = epath.Path(checkpoint_path) - if not checkpoint_path.exists(): - raise ValueError(f'Checkpoint path does not exist: {checkpoint_path}') - json_data = json.dumps(dataclasses.asdict(checkpoint_metadata)) - bytes_written = metadata_file_path(checkpoint_path).write_text(json_data) + metadata_file = epath.Path(file_path) + if not metadata_file.parent.exists(): + raise ValueError(f'Metadata path does not exist: {metadata_file.parent}') + json_data = json.dumps(metadata) + bytes_written = metadata_file.write_text(json_data) if bytes_written == 0: raise ValueError( - f'Failed to write CheckpointMetadata={checkpoint_metadata},' - f' json={json_data} to {checkpoint_path}' + f'Failed to write Metadata={metadata},' + f' json={json_data} to {metadata_file}' ) logging.log_every_n( logging.INFO, - 'Wrote CheckpointMetadata=%s, json=%s to %s', + 'Wrote Metadata=%s, json=%s to %s', 100, - checkpoint_metadata, + metadata, json_data, - checkpoint_path, + metadata_file, ) def read( - self, checkpoint_path: epath.PathLike - ) -> Optional[StepMetadata]: - metadata_file = metadata_file_path(checkpoint_path) + self, file_path: epath.PathLike + ) -> MetadataDict | None: + metadata_file = epath.Path(file_path) if not metadata_file.exists(): logging.warning( - 'CheckpointMetadata file does not exist: %s', metadata_file + 'Metadata file does not exist: %s', metadata_file ) return None try: raw_data = metadata_file.read_text() except Exception as e: # pylint: disable=broad-exception-caught logging.error( - 'Failed to read CheckpointMetadata file: %s, error: %s', + 'Failed to read Metadata file: %s, error: %s', metadata_file, e, ) @@ -163,43 +194,45 @@ def read( except json.decoder.JSONDecodeError as e: # TODO(b/340287956): Found empty metadata files, how is it possible. logging.error( - 'Failed to json parse CheckpointMetadata file: %s, file content: %s,' - ' error: %s', + 'Failed to json parse Metadata file: %s, ' + 'file content: %s, ' + 'error: %s', metadata_file, raw_data, e, ) return None - result = StepMetadata.from_dict(json_data) logging.log_every_n( logging.INFO, - 'Read CheckpointMetadata=%s from %s', + 'Read Metadata=%s from %s', 500, - result, - checkpoint_path, + json_data, + metadata_file, ) - return result + return json_data def update( self, - checkpoint_path: epath.PathLike, + file_path: epath.PathLike, **kwargs, ) -> None: - metadata = self.read(checkpoint_path) or StepMetadata() - updated = dataclasses.replace(metadata, **kwargs) - self.write(checkpoint_path, updated) + metadata_file = epath.Path(file_path) + metadata = self.read(metadata_file) or {} + for k, v in kwargs.items(): + metadata[k] = v + self.write(metadata_file, metadata) logging.log_every_n( logging.INFO, - 'Updated CheckpointMetadata=%s to %s', + 'Updated Metadata=%s to %s', 100, - updated, - checkpoint_path, + metadata, + metadata_file, ) @dataclasses.dataclass -class _BlockingCheckpointMetadataStore(CheckpointMetadataStore): - """Manages storage of `CheckpointMetadata` with blocking writes. +class _BlockingMetadataStore(MetadataStore): + """Manages storage of `Metadata` with blocking writes. Write operations are thread safe: within a process multiple threads write without corrupting data. @@ -215,12 +248,12 @@ class _BlockingCheckpointMetadataStore(CheckpointMetadataStore): """ enable_write: bool - # TODO(niketkb): Support locking per checkpoint path. + # TODO(niketkb): Support locking per path. _write_lock: threading.RLock = dataclasses.field(init=False) - _store_impl: _CheckpointMetadataStoreImpl = dataclasses.field(init=False) + _store_impl: _MetadataStoreImpl = dataclasses.field(init=False) def __post_init__(self): - self._store_impl = _CheckpointMetadataStoreImpl() + self._store_impl = _MetadataStoreImpl() if self.enable_write: self._write_lock = threading.RLock() @@ -229,44 +262,44 @@ def is_blocking_writer(self) -> bool: def write( self, - checkpoint_path: epath.PathLike, - checkpoint_metadata: StepMetadata, + file_path: epath.PathLike, + metadata: MetadataDict, ) -> None: if not self.enable_write: return with self._write_lock: - self._store_impl.write(checkpoint_path, checkpoint_metadata) + self._store_impl.write(file_path, metadata) def read( - self, checkpoint_path: epath.PathLike - ) -> Optional[StepMetadata]: - return self._store_impl.read(checkpoint_path) + self, file_path: epath.PathLike + ) -> MetadataDict | None: + return self._store_impl.read(file_path) def update( self, - checkpoint_path: epath.PathLike, + file_path: epath.PathLike, **kwargs, ) -> None: if not self.enable_write: return with self._write_lock: - self._store_impl.update(checkpoint_path, **kwargs) + self._store_impl.update(file_path, **kwargs) @dataclasses.dataclass -class _NonBlockingCheckpointMetadataStore(CheckpointMetadataStore): - """Manages storage of `CheckpointMetadata` with non blocking writes. +class _NonBlockingMetadataStore(MetadataStore): + """Manages storage of `Metadata` with non blocking writes. - By default it behaves like a read only `CheckpointMetadataStore`. But the same - instance is reused if user requests for a write-enabled instance in the same - process. + By default it behaves like a read only `MetadataStore`. But the + same instance is reused if user requests for a write-enabled instance in the + same process. The writes are non blocking. Read responses don't reflect in progress writes. """ enable_write: bool _write_lock: threading.RLock = dataclasses.field(init=False) - _store_impl: _CheckpointMetadataStoreImpl = dataclasses.field(init=False) + _store_impl: _MetadataStoreImpl = dataclasses.field(init=False) # We need to make sure that only one thread writes/updates to a given path. # A single threaded executor is a simple solution. We can improve it by # introducing a multi threaded executor but setting up tasks such that all @@ -281,7 +314,7 @@ class _NonBlockingCheckpointMetadataStore(CheckpointMetadataStore): ) def __post_init__(self): - self._store_impl = _CheckpointMetadataStoreImpl() + self._store_impl = _MetadataStoreImpl() if self.enable_write: self._write_lock = threading.RLock() self._single_thread_executor = concurrent.futures.ThreadPoolExecutor( @@ -303,90 +336,87 @@ def _add_to_write_futures( def _write_and_log( self, - checkpoint_path: epath.PathLike, - checkpoint_metadata: StepMetadata, + file_path: epath.PathLike, + metadata: MetadataDict, ) -> None: - """Writes `checkpoint_metadata` and logs error if any.""" + """Writes `metadata` and logs error if any.""" try: - self._store_impl.write(checkpoint_path, checkpoint_metadata) + self._store_impl.write(file_path, metadata) except Exception as e: # pylint: disable=broad-exception-caught logging.exception( 'Failed to write metadata=%s path=%s: %s', - checkpoint_metadata, - checkpoint_path, + metadata, + file_path, e, ) raise def write( self, - checkpoint_path: epath.PathLike, - checkpoint_metadata: StepMetadata, + file_path: epath.PathLike, + metadata: MetadataDict, ) -> None: - """[Over]Writes `checkpoint_metadata` in non blocking manner.""" + """[Over]Writes `metadata` in non blocking manner.""" if not self.enable_write: logging.warning( - 'Write requested but enable_write=false, checkpoint_metadata=%s' - ' checkpoint_path=%s', - checkpoint_metadata, - checkpoint_path, + 'Write requested but enable_write=false, metadata=%s' + ' path=%s', + metadata, + file_path, ) return with self._write_lock: future = self._single_thread_executor.submit( - self._write_and_log, checkpoint_path, checkpoint_metadata + self._write_and_log, file_path, metadata ) self._add_to_write_futures(future) def read( - self, checkpoint_path: epath.PathLike - ) -> Optional[StepMetadata]: - """Reads `checkpoint_path`/*metadata_file* and returns `CheckpointMetadata`.""" - return self._store_impl.read(checkpoint_path) + self, file_path: epath.PathLike + ) -> MetadataDict | None: + """Reads `file and returns a metadata dictionary.""" + return self._store_impl.read(file_path) - def _update_and_log(self, checkpoint_path: epath.PathLike, **kwargs) -> None: - """Updates checkpoint metadata attributes and logs error if any.""" + def _update_and_log(self, file: epath.PathLike, **kwargs) -> None: + """Updates metadata attributes and logs error if any.""" try: - self._store_impl.update(checkpoint_path, **kwargs) + self._store_impl.update(file, **kwargs) except Exception as e: # pylint: disable=broad-exception-caught logging.exception( 'Failed to update metadata=%s path=%s: %s', kwargs, - checkpoint_path, + file, e, ) raise - def _validate_kwargs(self, **kwargs) -> None: - _ = StepMetadata(**kwargs) - def update( self, - checkpoint_path: epath.PathLike, + file_path: epath.PathLike, **kwargs, ) -> None: - """Updates CheckpointMetadata in non blocking manner. + """Updates `Metadata` in non blocking manner. - If no updatable CheckpointMetadata is found at - `checkpoint_path`/*metadata_file*, then it creates a new one with `kwargs` + If no updatable `Metadata` is found at + `file`, then it creates a new one with `kwargs` attributes. Args: - checkpoint_path: path to checkpoint dir (step dir). - **kwargs: Attributes of CheckpointMetadata is kwargs format. + file_path: path to the metadata file in checkpoint dir (for + `RootMetadata`) or step dir (for `StepMetadata`). + **kwargs: Attributes of `Metadata` is kwargs format. """ if not self.enable_write: logging.warning( 'Update requested but enable_write=false, kwargs=%s' - ' checkpoint_path=%s', + ' file=%s', kwargs, - checkpoint_path, + file_path, ) return with self._write_lock: - self._validate_kwargs(**kwargs) future = self._single_thread_executor.submit( - self._update_and_log, checkpoint_path, **kwargs + self._update_and_log, file_path, **kwargs ) self._add_to_write_futures(future) @@ -407,23 +437,19 @@ def close(self) -> None: logging.info('Closing %s', self) -_CHECKPOINT_METADATA_STORE_FOR_WRITES = _BlockingCheckpointMetadataStore( - enable_write=True -) -_CHECKPOINT_METADATA_STORE_FOR_READS = _BlockingCheckpointMetadataStore( - enable_write=False -) -_CHECKPOINT_METADATA_STORE_NON_BLOCKING_FOR_READS = ( - _NonBlockingCheckpointMetadataStore(enable_write=False) +_METADATA_STORE_FOR_WRITES = _BlockingMetadataStore(enable_write=True) +_METADATA_STORE_FOR_READS = _BlockingMetadataStore(enable_write=False) +_METADATA_STORE_NON_BLOCKING_FOR_READS = ( + _NonBlockingMetadataStore(enable_write=False) ) -def checkpoint_metadata_store( +def metadata_store( *, enable_write: bool, blocking_write: bool = False, -) -> CheckpointMetadataStore: - """Returns `CheckpointMetadataStore` instance based on `enable_write` value. +) -> MetadataStore: + """Returns `MetadataStore` instance based on `enable_write` value. Write operations are thread safe: within a process multiple threads write without corrupting data. @@ -433,8 +459,8 @@ def checkpoint_metadata_store( Read operations are inherently thread safe and *process safe* too. - NOTE: `CheckpointMetadataStore` instance created with `enable_write=True` and - `blocking_write=False` must be closed with `.close()` to release thread + NOTE: `MetadataStore` instance created with `enable_write=True` + and `blocking_write=False` must be closed with `.close()` to release thread resources. Prefer to reuse an instance created for this scenario. Args: @@ -445,9 +471,9 @@ def checkpoint_metadata_store( """ if not blocking_write: if enable_write: - return _NonBlockingCheckpointMetadataStore(enable_write=True) - return _CHECKPOINT_METADATA_STORE_NON_BLOCKING_FOR_READS + return _NonBlockingMetadataStore(enable_write=True) + return _METADATA_STORE_NON_BLOCKING_FOR_READS if enable_write: - return _CHECKPOINT_METADATA_STORE_FOR_WRITES - return _CHECKPOINT_METADATA_STORE_FOR_READS + return _METADATA_STORE_FOR_WRITES + return _METADATA_STORE_FOR_READS diff --git a/checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py b/checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py index 44ede49f..8be7b036 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py @@ -12,273 +12,424 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import pickle -import time +from typing import Any + from absl.testing import absltest from absl.testing import parameterized from etils import epath +from orbax.checkpoint._src.logging import step_statistics from orbax.checkpoint._src.metadata import checkpoint +from orbax.checkpoint._src.metadata import root_metadata_serialization +from orbax.checkpoint._src.metadata import step_metadata_serialization + + +StepMetadata = checkpoint.StepMetadata +RootMetadata = checkpoint.RootMetadata +MetadataStore = checkpoint.MetadataStore +metadata_store = checkpoint.metadata_store +StepStatistics = step_statistics.SaveStepStatistics +SAMPLE_FORMAT = 'sample_format' -class CheckpointMetadataStoreTest(parameterized.TestCase): + +class CheckpointMetadataTest(parameterized.TestCase): def setUp(self): super().setUp() self.directory = epath.Path(self.create_tempdir().full_path) - checkpoint._CHECKPOINT_METADATA_STORE_FOR_WRITES = ( - checkpoint._BlockingCheckpointMetadataStore(enable_write=True) - ) - checkpoint._CHECKPOINT_METADATA_STORE_FOR_READS = ( - checkpoint._BlockingCheckpointMetadataStore(enable_write=False) - ) - self._non_blocking_store_for_writes = checkpoint.checkpoint_metadata_store( + self._non_blocking_store_for_writes = metadata_store( enable_write=True, blocking_write=False ) - checkpoint._CHECKPOINT_METADATA_STORE_NON_BLOCKING_FOR_READS = ( - checkpoint._NonBlockingCheckpointMetadataStore(enable_write=False) - ) def tearDown(self): super().tearDown() - checkpoint._CHECKPOINT_METADATA_STORE_FOR_WRITES.close() - checkpoint._CHECKPOINT_METADATA_STORE_FOR_READS.close() + checkpoint._METADATA_STORE_FOR_WRITES.close() + checkpoint._METADATA_STORE_FOR_READS.close() self._non_blocking_store_for_writes.close() - checkpoint._CHECKPOINT_METADATA_STORE_NON_BLOCKING_FOR_READS.close() + checkpoint._METADATA_STORE_NON_BLOCKING_FOR_READS.close() def write_metadata_store( self, blocking_write: bool - ) -> checkpoint.CheckpointMetadataStore: + ) -> MetadataStore: if blocking_write: - return checkpoint.checkpoint_metadata_store( + return metadata_store( enable_write=True, blocking_write=True ) return self._non_blocking_store_for_writes def read_metadata_store( self, blocking_write: bool - ) -> checkpoint.CheckpointMetadataStore: - return checkpoint.checkpoint_metadata_store( + ) -> MetadataStore: + return metadata_store( enable_write=False, blocking_write=blocking_write ) + def create_metadata( + self, + metadata_cls: StepMetadata | RootMetadata, + metadata_dict: checkpoint.MetadataDict, + item_metadata: checkpoint.ItemMetadata | None = None, + metrics: dict[str, Any] | None = None, + ) -> StepMetadata | RootMetadata: + if metadata_cls == StepMetadata: + return step_metadata_serialization.deserialize( + metadata_dict, + item_metadata=item_metadata, + metrics=metrics, + ) + elif metadata_cls == RootMetadata: + return root_metadata_serialization.deserialize(metadata_dict) + + def serialize_metadata( + self, + metadata: StepMetadata | RootMetadata, + ) -> dict[str, Any]: + if isinstance(metadata, StepMetadata): + return step_metadata_serialization.serialize(metadata) + elif isinstance(metadata, RootMetadata): + return root_metadata_serialization.serialize(metadata) + + def get_metadata( # pylint: disable=dangerous-default-value + self, + metadata_cls: StepMetadata | RootMetadata, + custom: dict[str, Any] = {'a': 1}, + ) -> StepMetadata | RootMetadata: + if metadata_cls == StepMetadata: + return StepMetadata( + format=SAMPLE_FORMAT, + item_handlers={'a': 'b'}, + item_metadata={'a': None}, + metrics={'a': 1}, + performance_metrics=StepStatistics( + step=None, + event_type='save', + reached_preemption=False, + preemption_received_at=1.0, + ), + init_timestamp_nsecs=1, + commit_timestamp_nsecs=1, + custom=custom, + ) + elif metadata_cls == RootMetadata: + return RootMetadata( + format=SAMPLE_FORMAT, + custom=custom, + ) + @parameterized.parameters(True, False) def test_read_unknown_path(self, blocking_write: bool): self.assertIsNone( self.write_metadata_store(blocking_write).read( - checkpoint_path='unknown_checkpoint_path' + file_path='unknown_checkpoint_path' ) ) - @parameterized.parameters(True, False) - def test_write_unknown_path(self, blocking_write: bool): + @parameterized.parameters( + itertools.product( + [True, False], + [StepMetadata, RootMetadata], + ) + ) + def test_write_unknown_file_path( + self, + blocking_write: bool, + metadata_cls: StepMetadata | RootMetadata, + ): + metadata = metadata_cls() + if blocking_write: - with self.assertRaisesRegex(ValueError, 'Checkpoint path does not exist'): + with self.assertRaisesRegex(ValueError, 'Metadata path does not exist'): self.write_metadata_store(blocking_write).write( - checkpoint_path='unknown_checkpoint_path', - checkpoint_metadata=checkpoint.StepMetadata(), + file_path=metadata.file_path('unknown_metadata_path'), + metadata=self.serialize_metadata(metadata), ) else: self.write_metadata_store(blocking_write).write( - checkpoint_path='unknown_checkpoint_path', - checkpoint_metadata=checkpoint.StepMetadata(), + file_path=metadata.file_path('unknown_metadata_path'), + metadata=self.serialize_metadata(metadata), ) try: self.write_metadata_store(blocking_write).wait_until_finished() except ValueError: # We don't want to fail the test because above write's future.result() - # will raise 'ValueError: Checkpoint path does not exist ...'. + # will raise 'ValueError: Metadata file does not exist ...'. pass self.assertIsNone( self.read_metadata_store(blocking_write).read( - checkpoint_path='unknown_checkpoint_path' + file_path=metadata.file_path('unknown_metadata_path'), ) ) - @parameterized.parameters(True, False) - def test_read_default_values(self, blocking_write: bool): - metadata = checkpoint.StepMetadata() + @parameterized.parameters( + itertools.product( + [True, False], + [StepMetadata, RootMetadata], + ) + ) + def test_read_default_values( + self, + blocking_write: bool, + metadata_cls: StepMetadata | RootMetadata, + ): + metadata = metadata_cls() self.write_metadata_store(blocking_write).write( - checkpoint_path=self.directory, checkpoint_metadata=metadata + file_path=metadata.file_path(self.directory), + metadata=self.serialize_metadata(metadata), ) self.write_metadata_store(blocking_write).wait_until_finished() + metadata_dict = self.write_metadata_store(blocking_write).read( + file_path=metadata.file_path(self.directory), + ) self.assertEqual( - self.write_metadata_store(blocking_write).read( - checkpoint_path=self.directory - ), + self.create_metadata(metadata_cls, metadata_dict), metadata, ) - @parameterized.parameters(True, False) - def test_read_with_values(self, blocking_write: bool): - metadata = checkpoint.StepMetadata( - init_timestamp_nsecs=time.time_ns(), - commit_timestamp_nsecs=time.time_ns() + 1, - ) + @parameterized.parameters( + itertools.product( + [True, False], + [StepMetadata, RootMetadata], + ) + ) + def test_read_with_values( + self, + blocking_write: bool, + metadata_cls: StepMetadata | RootMetadata, + ): + metadata = self.get_metadata(metadata_cls) + self.write_metadata_store(blocking_write).write( - checkpoint_path=self.directory, checkpoint_metadata=metadata + file_path=metadata.file_path(self.directory), + metadata=self.serialize_metadata(metadata), ) self.write_metadata_store(blocking_write).wait_until_finished() + metadata_dict = self.write_metadata_store(blocking_write).read( + file_path=metadata.file_path(self.directory), + ) self.assertEqual( - self.write_metadata_store(blocking_write).read( - checkpoint_path=self.directory - ), + self.create_metadata(metadata_cls, metadata_dict), metadata, ) - @parameterized.parameters(True, False) - def test_read_corrupt_json_data(self, blocking_write: bool): - metadata_file = checkpoint.metadata_file_path(self.directory) + @parameterized.parameters( + itertools.product( + [True, False], + [StepMetadata, RootMetadata], + ) + ) + def test_read_corrupt_json_data( + self, + blocking_write: bool, + metadata_cls: StepMetadata | RootMetadata, + ): + metadata = metadata_cls() + metadata_file = metadata.file_path(self.directory) metadata_file.touch() self.assertIsNone( self.write_metadata_store(blocking_write).read( - checkpoint_path=self.directory + file_path=metadata.file_path(self.directory) ) ) - @parameterized.parameters(True, False) - def test_update_without_prior_data(self, blocking_write: bool): + @parameterized.parameters( + itertools.product( + [True, False], + [StepMetadata, RootMetadata], + ) + ) + def test_update_without_prior_data( + self, + blocking_write: bool, + metadata_cls: StepMetadata | RootMetadata, + ): + metadata = metadata_cls() self.write_metadata_store(blocking_write).update( - checkpoint_path=self.directory, - init_timestamp_nsecs=1, - commit_timestamp_nsecs=2, + file_path=metadata.file_path(self.directory), + format=SAMPLE_FORMAT, + custom={'a': 1}, ) self.write_metadata_store(blocking_write).wait_until_finished() + metadata_dict = self.write_metadata_store(blocking_write).read( + file_path=metadata.file_path(self.directory), + ) self.assertEqual( - self.write_metadata_store(blocking_write).read( - checkpoint_path=self.directory - ), - checkpoint.StepMetadata( - init_timestamp_nsecs=1, - commit_timestamp_nsecs=2, + self.create_metadata(metadata_cls, metadata_dict), + metadata_cls( + format=SAMPLE_FORMAT, + custom={'a': 1}, ), ) - @parameterized.parameters(True, False) - def test_update_with_prior_data(self, blocking_write: bool): - metadata = checkpoint.StepMetadata(init_timestamp_nsecs=1) + @parameterized.parameters( + itertools.product( + [True, False], + [StepMetadata, RootMetadata], + ) + ) + def test_update_with_prior_data( + self, + blocking_write: bool, + metadata_cls: StepMetadata | RootMetadata, + ): + metadata = metadata_cls(format=SAMPLE_FORMAT) self.write_metadata_store(blocking_write).write( - checkpoint_path=self.directory, checkpoint_metadata=metadata + file_path=metadata.file_path(self.directory), + metadata=self.serialize_metadata(metadata), ) self.write_metadata_store(blocking_write).update( - checkpoint_path=self.directory, - commit_timestamp_nsecs=2, + file_path=metadata.file_path(self.directory), + custom={'a': 1}, ) self.write_metadata_store(blocking_write).wait_until_finished() + metadata_dict = self.write_metadata_store(blocking_write).read( + file_path=metadata.file_path(self.directory) + ) self.assertEqual( - self.write_metadata_store(blocking_write).read( - checkpoint_path=self.directory - ), - checkpoint.StepMetadata( - init_timestamp_nsecs=1, - commit_timestamp_nsecs=2, + self.create_metadata(metadata_cls, metadata_dict), + metadata_cls( + format=SAMPLE_FORMAT, + custom={'a': 1}, ), ) - @parameterized.parameters(True, False) - def test_update_with_unknown_kwargs(self, blocking_write: bool): - with self.assertRaisesRegex( - TypeError, "got an unexpected keyword argument 'blah'" - ): - self.write_metadata_store(blocking_write).update( - checkpoint_path=self.directory, - init_timestamp_nsecs=1, - blah=2, + @parameterized.parameters( + itertools.product( + [True, False], + [StepMetadata, RootMetadata], ) + ) + def test_update_with_unknown_kwargs( + self, blocking_write: bool, metadata_cls: StepMetadata | RootMetadata + ): + metadata = metadata_cls() + self.write_metadata_store(blocking_write).write( + file_path=metadata.file_path(self.directory), + metadata=self.serialize_metadata(metadata), + ) + self.write_metadata_store(blocking_write).update( + file_path=metadata.file_path(self.directory), + format=SAMPLE_FORMAT, + blah=2, + ) - @parameterized.parameters(True, False) - def test_write_with_read_only_store_is_no_op(self, blocking_write: bool): - self.assertIsNone( - self.read_metadata_store(blocking_write).read( - checkpoint_path=self.directory - ) + self.write_metadata_store(blocking_write).wait_until_finished() + + metadata_dict = self.write_metadata_store(blocking_write).read( + file_path=metadata.file_path(self.directory) ) + self.assertEqual( + self.create_metadata(metadata_cls, metadata_dict), + metadata_cls( + format=SAMPLE_FORMAT, + ), + ) + + @parameterized.parameters( + itertools.product( + [True, False], + [StepMetadata, RootMetadata], + ) + ) + def test_write_with_read_only_store_is_no_op( + self, + blocking_write: bool, + metadata_cls: StepMetadata | RootMetadata, + ): + metadata = metadata_cls() self.read_metadata_store(blocking_write).write( - checkpoint_path=self.directory, - checkpoint_metadata=checkpoint.StepMetadata(), + file_path=metadata.file_path(self.directory), + metadata=self.serialize_metadata(metadata), ) - self.assertIsNone( - self.read_metadata_store(blocking_write).read( - checkpoint_path=self.directory - ) + metadata_dict = self.read_metadata_store(blocking_write).read( + file_path=metadata.file_path(self.directory) ) - self.assertIsNone( - self.write_metadata_store(blocking_write).read( - checkpoint_path=self.directory - ) + self.assertIsNone(metadata_dict) + + metadata_dict = self.write_metadata_store(blocking_write).read( + file_path=metadata.file_path(self.directory) ) + self.assertIsNone(metadata_dict) + + @parameterized.parameters(StepMetadata, RootMetadata) + def test_non_blocking_write_request_enables_writes( + self, metadata_cls: StepMetadata | RootMetadata, + ): + metadata = self.get_metadata(metadata_cls) - def test_non_blocking_write_request_enables_writes(self): # setup some data with blocking store. self.write_metadata_store(blocking_write=True).write( - checkpoint_path=self.directory, - checkpoint_metadata=checkpoint.StepMetadata(init_timestamp_nsecs=1), + file_path=metadata.file_path(self.directory), + metadata=self.serialize_metadata(metadata), ) + metadata_dict = self.read_metadata_store(blocking_write=True).read( + file_path=metadata.file_path(self.directory) + ) self.assertEqual( - self.read_metadata_store(blocking_write=False).read( - checkpoint_path=self.directory - ), - checkpoint.StepMetadata(init_timestamp_nsecs=1), + self.create_metadata(metadata_cls, metadata_dict), + self.get_metadata(metadata_cls), ) # write validations + metadata_dict = self.serialize_metadata( + self.get_metadata(metadata_cls, custom={'a': 2}) + ) self.write_metadata_store(blocking_write=False).write( - checkpoint_path=self.directory, - checkpoint_metadata=checkpoint.StepMetadata( - init_timestamp_nsecs=2, commit_timestamp_nsecs=3 - ), + file_path=metadata.file_path(self.directory), + metadata=metadata_dict, ) + self.write_metadata_store(blocking_write=False).wait_until_finished() + + metadata_dict = self.read_metadata_store(blocking_write=False).read( + file_path=metadata.file_path(self.directory) + ) self.assertEqual( - self.read_metadata_store(blocking_write=False).read( - checkpoint_path=self.directory - ), - checkpoint.StepMetadata( - init_timestamp_nsecs=2, commit_timestamp_nsecs=3 - ), + self.create_metadata(metadata_cls, metadata_dict), + self.get_metadata(metadata_cls, custom={'a': 2}), + ) + metadata_dict = self.write_metadata_store(blocking_write=False).read( + file_path=metadata.file_path(self.directory) ) self.assertEqual( - self.write_metadata_store(blocking_write=False).read( - checkpoint_path=self.directory - ), - checkpoint.StepMetadata( - init_timestamp_nsecs=2, commit_timestamp_nsecs=3 - ), + self.create_metadata(metadata_cls, metadata_dict), + self.get_metadata(metadata_cls, custom={'a': 2}), ) # update validations self.write_metadata_store(blocking_write=False).update( - checkpoint_path=self.directory, commit_timestamp_nsecs=7 + file_path=metadata.file_path(self.directory), + custom={'a': 3}, ) self.write_metadata_store(blocking_write=False).wait_until_finished() + metadata_dict = self.read_metadata_store(blocking_write=False).read( + file_path=metadata.file_path(self.directory) + ) self.assertEqual( - self.read_metadata_store(blocking_write=False).read( - checkpoint_path=self.directory - ), - checkpoint.StepMetadata( - init_timestamp_nsecs=2, commit_timestamp_nsecs=7 - ), + self.create_metadata(metadata_cls, metadata_dict), + self.get_metadata(metadata_cls, custom={'a': 3}), + ) + metadata_dict = self.write_metadata_store(blocking_write=False).read( + file_path=metadata.file_path(self.directory) ) self.assertEqual( - self.write_metadata_store(blocking_write=False).read( - checkpoint_path=self.directory - ), - checkpoint.StepMetadata( - init_timestamp_nsecs=2, commit_timestamp_nsecs=7 - ), + self.create_metadata(metadata_cls, metadata_dict), + self.get_metadata(metadata_cls, custom={'a': 3}), ) @parameterized.parameters(True, False) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/root_metadata_serialization.py b/checkpoint/orbax/checkpoint/_src/metadata/root_metadata_serialization.py new file mode 100644 index 00000000..cfc8ccc5 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/metadata/root_metadata_serialization.py @@ -0,0 +1,67 @@ +# Copyright 2024 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal IO utilities for metadata of a checkpoint at root level.""" + +from absl import logging +from orbax.checkpoint._src.metadata import checkpoint + +MetadataDict = checkpoint.MetadataDict +RootMetadata = checkpoint.RootMetadata + + +def serialize(metadata: RootMetadata) -> MetadataDict: + """Serializes `metadata` to a dictionary.""" + return { + 'format': metadata.format, + 'custom': metadata.custom, + } + + +def deserialize(metadata_dict: MetadataDict) -> RootMetadata | None: + """Deserializes `metadata_dict` to `RootMetadata`.""" + validated_metadata_dict = {} + + if metadata_dict.get('format') is not None: + if not isinstance(metadata_dict['format'], str): + raise ValueError( + 'RootMetadata format must be a string, got' + f' {type(metadata_dict["format"])}.' + ) + validated_metadata_dict['format'] = metadata_dict['format'] + else: + validated_metadata_dict['format'] = None + + if metadata_dict.get('custom') is not None: + if not isinstance(metadata_dict['custom'], dict): + raise ValueError( + 'RootMetadata custom must be a dictionary, got' + f' {type(metadata_dict["custom"])}.' + ) + if not all(isinstance(k, str) for k in metadata_dict['custom'].keys()): + raise ValueError( + 'RootMetadata custom keys must be strings, got' + f' {type(metadata_dict["custom"].keys())}.' + ) + validated_metadata_dict['custom'] = metadata_dict['custom'] + else: + validated_metadata_dict['custom'] = {} + + for k in metadata_dict: + if k not in validated_metadata_dict: + logging.warning( + 'Provided metadata contains unknown key %s, ignoring.', k + ) + + return RootMetadata(**validated_metadata_dict) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/step_metadata_serialization.py b/checkpoint/orbax/checkpoint/_src/metadata/step_metadata_serialization.py new file mode 100644 index 00000000..320f1002 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/metadata/step_metadata_serialization.py @@ -0,0 +1,224 @@ +# Copyright 2024 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal IO utilities for metadata of a checkpoint at step level.""" + +import dataclasses +from typing import Any + +from absl import logging +from orbax.checkpoint._src.logging import step_statistics +from orbax.checkpoint._src.metadata import checkpoint + +MetadataDict = checkpoint.MetadataDict +StepMetadata = checkpoint.StepMetadata +ItemMetadata = checkpoint.ItemMetadata +StepStatistics = step_statistics.SaveStepStatistics + + +def serialize(metadata: StepMetadata) -> MetadataDict: + """Serializes `metadata` to a dictionary.""" + + # Part of the StepMetadata api for user convenience, but not saved to disk. + if metadata.item_metadata is not None: + just_item_names = {k: None for k in metadata.item_metadata.keys()} + else: + just_item_names = None + + # Save only float performance metrics. + performance_metrics = metadata.performance_metrics + float_metrics = { + metric: val + for metric, val in dataclasses.asdict(performance_metrics).items() + if isinstance(val, float) + } + + return { + 'format': metadata.format, + 'item_handlers': metadata.item_handlers, + 'item_metadata': just_item_names, + 'metrics': metadata.metrics, + 'performance_metrics': float_metrics, + 'init_timestamp_nsecs': metadata.init_timestamp_nsecs, + 'commit_timestamp_nsecs': metadata.commit_timestamp_nsecs, + 'custom': metadata.custom, + } + + +def deserialize( + metadata_dict: MetadataDict, + item_metadata: ItemMetadata | None = None, + metrics: dict[str, Any] | None = None, +) -> StepMetadata | None: + """Deserializes `metadata_dict` and other kwargs to `StepMetadata`.""" + validated_metadata_dict = {} + + if metadata_dict.get('format') is not None: + if not isinstance(metadata_dict['format'], str): + raise ValueError( + 'StepMetadata format must be a string, got' + f' {type(metadata_dict["format"])}.' + ) + validated_metadata_dict['format'] = metadata_dict['format'] + else: + validated_metadata_dict['format'] = None + + if metadata_dict.get('item_handlers') is not None: + if not isinstance(metadata_dict['item_handlers'], dict): + raise ValueError( + 'StepMetadata item_handlers must be a dictionary, got' + f' {type(metadata_dict["item_handlers"])}.' + ) + for k, v in metadata_dict['item_handlers'].items(): + if not isinstance(k, str): + raise ValueError( + f'StepMetadata item_handlers keys must be strings, got {type(k)}.' + ) + if not isinstance(v, str): + raise ValueError( + f'StepMetadata item_handlers values must be strings, got {type(v)}.' + ) + validated_metadata_dict['item_handlers'] = metadata_dict['item_handlers'] + else: + validated_metadata_dict['item_handlers'] = {} + + if metadata_dict.get('item_metadata') is not None: + if not isinstance(metadata_dict['item_metadata'], dict): + raise ValueError( + 'StepMetadata item_metadata must be a dictionary, got' + f' {type(metadata_dict["item_metadata"])}.' + ) + for k, v in metadata_dict['item_metadata'].items(): + if not isinstance(k, str): + raise ValueError( + f'StepMetadata item_metadata keys must be strings, got {type(k)}.' + ) + if v is not None: + raise ValueError( + f'StepMetadata item_metadata values must be None, got {type(v)}.' + ) + validated_metadata_dict['item_metadata'] = metadata_dict['item_metadata'] + else: + validated_metadata_dict['item_metadata'] = None + if item_metadata is not None: + if validated_metadata_dict['item_metadata'] is None: + validated_metadata_dict['item_metadata'] = {} + if not isinstance(item_metadata, ItemMetadata): + raise ValueError( + 'StepMetadata item_metadata must be of type ItemMetadata, got' + f' {type(item_metadata)}.' + ) + for k, v in item_metadata.items(): + if not isinstance(k, str): + raise ValueError( + f'StepMetadata item_metadata keys must be strings, got {type(k)}.' + ) + validated_metadata_dict['item_metadata'][k] = v + + if metadata_dict.get('metrics') is not None: + if not isinstance(metadata_dict['metrics'], dict): + raise ValueError( + 'StepMetadata metrics must be a dictionary, got' + f' {type(metadata_dict["metrics"])}.' + ) + for k in metadata_dict['metrics']: + if not isinstance(k, str): + raise ValueError( + f'StepMetadata metrics keys must be strings, got {type(k)}.' + ) + validated_metadata_dict['metrics'] = metadata_dict['metrics'] + else: + validated_metadata_dict['metrics'] = {} + if metrics is not None: + if not isinstance(metrics, dict): + raise ValueError( + 'StepMetadata metrics must be a dictionary, got' + f' {type(metrics)}.' + ) + for k, v in metrics.items(): + if not isinstance(k, str): + raise ValueError( + f'StepMetadata metrics keys must be strings, got {type(k)}.' + ) + validated_metadata_dict['metrics'][k] = v + + if metadata_dict.get('performance_metrics') is not None: + if not isinstance(metadata_dict['performance_metrics'], dict): + raise ValueError( + 'StepMetadata performance_metrics must be a dictionary, got' + f' {type(metadata_dict["performance_metrics"])}.' + ) + for k, v in metadata_dict['performance_metrics'].items(): + if not isinstance(k, str): + raise ValueError( + f'StepMetadata performance_metrics keys must be strings, got' + f' {type(k)}.' + ) + if not isinstance(v, float): + raise ValueError( + f'StepMetadata performance_metrics values must be floats, got' + f' {type(v)}.' + ) + validated_metadata_dict['performance_metrics'] = StepStatistics( + **metadata_dict['performance_metrics'] + ) + else: + validated_metadata_dict['performance_metrics'] = StepStatistics() + + if metadata_dict.get('init_timestamp_nsecs') is not None: + if not isinstance(metadata_dict['init_timestamp_nsecs'], int): + raise ValueError( + 'StepMetadata init_timestamp_nsecs must be an integer, got' + f' {type(metadata_dict["init_timestamp_nsecs"])}.' + ) + validated_metadata_dict['init_timestamp_nsecs'] = ( + metadata_dict['init_timestamp_nsecs'] + ) + else: + validated_metadata_dict['init_timestamp_nsecs'] = None + + if metadata_dict.get('commit_timestamp_nsecs') is not None: + if not isinstance(metadata_dict['commit_timestamp_nsecs'], int): + raise ValueError( + 'StepMetadata commit_timestamp_nsecs must be an integer, got' + f' {type(metadata_dict["commit_timestamp_nsecs"])}.' + ) + validated_metadata_dict['commit_timestamp_nsecs'] = ( + metadata_dict['commit_timestamp_nsecs'] + ) + else: + validated_metadata_dict['commit_timestamp_nsecs'] = None + + if metadata_dict.get('custom') is not None: + if not isinstance(metadata_dict['custom'], dict): + raise ValueError( + 'StepMetadata custom must be a dictionary, got' + f' {type(metadata_dict["custom"])}.' + ) + for k in metadata_dict['custom']: + if not isinstance(k, str): + raise ValueError( + f'StepMetadata custom keys must be strings, got {type(k)}.' + ) + validated_metadata_dict['custom'] = metadata_dict['custom'] + else: + validated_metadata_dict['custom'] = {} + + for k in metadata_dict: + if k not in validated_metadata_dict: + logging.warning( + 'Provided metadata contains unknown key %s, ignoring.', k + ) + + return StepMetadata(**validated_metadata_dict) diff --git a/checkpoint/orbax/checkpoint/async_checkpointer.py b/checkpoint/orbax/checkpoint/async_checkpointer.py index 97cb16c7..a82059f1 100644 --- a/checkpoint/orbax/checkpoint/async_checkpointer.py +++ b/checkpoint/orbax/checkpoint/async_checkpointer.py @@ -281,9 +281,7 @@ def __init__( async_options: options_lib.AsyncOptions = options_lib.AsyncOptions(), multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(), file_options: options_lib.FileOptions = options_lib.FileOptions(), - checkpoint_metadata_store: Optional[ - checkpoint.CheckpointMetadataStore - ] = None, + metadata_store: Optional[checkpoint.MetadataStore] = None, temporary_path_class: Optional[Type[atomicity.TemporaryPath]] = None, ): jax.monitoring.record_event('/jax/orbax/async_checkpointer/init') @@ -310,9 +308,8 @@ def __init__( ) self._barrier_sync_key_prefix = barrier_sync_key_prefix self._file_options = file_options - self._checkpoint_metadata_store = ( - checkpoint_metadata_store - or checkpoint.checkpoint_metadata_store(enable_write=True) + self._metadata_store = ( + metadata_store or checkpoint.metadata_store(enable_write=True) ) self._temporary_path_class = temporary_path_class timeout_secs = timeout_secs or async_options.timeout_secs @@ -445,18 +442,18 @@ def restore(self, directory: epath.PathLike, *args, **kwargs) -> Any: def check_for_errors(self): """Surfaces any errors from the background commit operations.""" self._async_manager.check_for_errors() - self._checkpoint_metadata_store.wait_until_finished() + self._metadata_store.wait_until_finished() def wait_until_finished(self): """Waits for any outstanding operations to finish.""" self._async_manager.wait_until_finished() - self._checkpoint_metadata_store.wait_until_finished() + self._metadata_store.wait_until_finished() def close(self): """Waits to finish any outstanding operations before closing.""" self.wait_until_finished() super().close() - self._checkpoint_metadata_store.close() + self._metadata_store.close() @property def handler(self) -> async_checkpoint_handler.AsyncCheckpointHandler: diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index b1c15fe2..8e8626cb 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -82,7 +82,6 @@ DEFAULT_ITEM_NAME = 'default' DESCRIPTOR_ITEM_NAME = 'descriptor' METRIC_ITEM_NAME = 'metrics' -METADATA_ITEM_NAME = 'metadata' RESERVED_ITEM_NAMES = [DESCRIPTOR_ITEM_NAME, METRIC_ITEM_NAME] _INIT_TIME = datetime.datetime.now(tz=datetime.timezone.utc) @@ -593,14 +592,12 @@ def __init__( ) # For async_checkpointer. - self._non_blocking_checkpoint_metadata_store = ( - checkpoint.checkpoint_metadata_store(enable_write=True) + self._non_blocking_metadata_store = ( + checkpoint.metadata_store(enable_write=True) ) # For metadata checkpointer and regular checkpointer. - self._blocking_checkpoint_metadata_store = ( - checkpoint.checkpoint_metadata_store( - enable_write=True, blocking_write=True - ) + self._blocking_metadata_store = ( + checkpoint.metadata_store(enable_write=True, blocking_write=True) ) if checkpointers is not None: @@ -690,7 +687,7 @@ def __init__( ), multiprocessing_options=self._options.multiprocessing_options, file_options=self._options.file_options, - checkpoint_metadata_store=self._blocking_checkpoint_metadata_store, + metadata_store=self._blocking_metadata_store, temporary_path_class=self._options.temporary_path_class, ) if self._options.read_only and not self._metadata_path().exists(): @@ -760,7 +757,7 @@ def _configure_checkpointer_common( multiprocessing_options=options.multiprocessing_options, async_options=options.async_options or AsyncOptions(), file_options=options.file_options, - checkpoint_metadata_store=self._non_blocking_checkpoint_metadata_store, + metadata_store=self._non_blocking_metadata_store, temporary_path_class=options.temporary_path_class, ) else: @@ -768,7 +765,7 @@ def _configure_checkpointer_common( handler, multiprocessing_options=options.multiprocessing_options, file_options=options.file_options, - checkpoint_metadata_store=self._blocking_checkpoint_metadata_store, + metadata_store=self._blocking_metadata_store, temporary_path_class=options.temporary_path_class, ) @@ -1480,7 +1477,7 @@ def _add_checkpoint_info(self, step: int, metrics: Optional[PyTree]): ) def _metadata_path(self) -> epath.Path: - return self.directory / METADATA_ITEM_NAME + return checkpoint.RootMetadata.file_path(self.directory).parent def _save_metadata(self, metadata: Mapping[str, Any]): """Saves CheckpointManager level metadata, skips if already present.""" @@ -1495,6 +1492,12 @@ def _save_metadata(self, metadata: Mapping[str, Any]): ), processes=self._multiprocessing_options.active_processes, ) + logging.info( + '[process=%s][thread=%s][_save_metadata] path_exists=%s', + multihost.process_index(), + threading.current_thread().name, + path_exists, + ) if not path_exists: # May have been created by a previous run. self._metadata_checkpointer.save(path, metadata) @@ -1765,7 +1768,7 @@ def _finalize_checkpoint(self, step: int): def _finalize(self, step: int, steps_to_remove: List[int]): """Finalizes individual items and starts garbage collection.""" process_index = multihost.process_index() - self._non_blocking_checkpoint_metadata_store.wait_until_finished() + self._non_blocking_metadata_store.wait_until_finished() self._wait_for_checkpointers() # If an error is encountered while waiting for commit futures to complete, # we will not proceed past this point. @@ -1804,8 +1807,8 @@ def close(self): self.wait_until_finished() self._checkpointer.close() # Call after checkpointer.close(). - self._non_blocking_checkpoint_metadata_store.close() - self._blocking_checkpoint_metadata_store.close() + self._non_blocking_metadata_store.close() + self._blocking_metadata_store.close() self._checkpoint_deleter.close() def __contextmanager__( diff --git a/checkpoint/orbax/checkpoint/checkpointer.py b/checkpoint/orbax/checkpoint/checkpointer.py index 0f9a719e..c4c72be9 100644 --- a/checkpoint/orbax/checkpoint/checkpointer.py +++ b/checkpoint/orbax/checkpoint/checkpointer.py @@ -105,9 +105,7 @@ def __init__( *, multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(), file_options: options_lib.FileOptions = options_lib.FileOptions(), - checkpoint_metadata_store: Optional[ - checkpoint.CheckpointMetadataStore - ] = None, + metadata_store: Optional[checkpoint.MetadataStore] = None, temporary_path_class: Optional[Type[atomicity.TemporaryPath]] = None, ): if not checkpoint_args.has_registered_args(handler): @@ -126,13 +124,11 @@ def __init__( self._temporary_path_class = temporary_path_class # If not provided then use checkpoint_metadata_store with blocking writes. - self._checkpoint_metadata_store = ( - checkpoint_metadata_store - or checkpoint.checkpoint_metadata_store( - enable_write=True, blocking_write=True - ) + self._metadata_store = ( + metadata_store + or checkpoint.metadata_store(enable_write=True, blocking_write=True) ) - if not self._checkpoint_metadata_store.is_blocking_writer(): + if not self._metadata_store.is_blocking_writer(): raise ValueError('Checkpoint metadata store must be blocking writer.') jax.monitoring.record_event('/jax/orbax/checkpointer/init') @@ -151,7 +147,7 @@ async def create_temporary_path( ) tmpdir = temporary_path_class.from_final( directory, - checkpoint_metadata_store=self._checkpoint_metadata_store, + metadata_store=self._metadata_store, multiprocessing_options=multiprocessing_options, file_options=self._file_options, ) @@ -256,7 +252,7 @@ def metadata(self, directory: epath.PathLike) -> Optional[Any]: def close(self): """Closes the underlying CheckpointHandler.""" self._handler.close() - self._checkpoint_metadata_store.close() + self._metadata_store.close() @property def handler(self) -> checkpoint_handler.CheckpointHandler: diff --git a/checkpoint/orbax/checkpoint/metadata/__init__.py b/checkpoint/orbax/checkpoint/metadata/__init__.py index e9302c68..0d0361eb 100644 --- a/checkpoint/orbax/checkpoint/metadata/__init__.py +++ b/checkpoint/orbax/checkpoint/metadata/__init__.py @@ -17,8 +17,8 @@ # pylint: disable=g-importing-member, g-bad-import-order from orbax.checkpoint._src.metadata.checkpoint import StepMetadata -from orbax.checkpoint._src.metadata.checkpoint import CheckpointMetadataStore -from orbax.checkpoint._src.metadata.checkpoint import checkpoint_metadata_store +from orbax.checkpoint._src.metadata.checkpoint import MetadataStore +from orbax.checkpoint._src.metadata.checkpoint import metadata_store from orbax.checkpoint._src.metadata.sharding import ShardingMetadata from orbax.checkpoint._src.metadata.sharding import NamedShardingMetadata diff --git a/checkpoint/orbax/checkpoint/path/atomicity.py b/checkpoint/orbax/checkpoint/path/atomicity.py index 678ee2b3..77ffff26 100644 --- a/checkpoint/orbax/checkpoint/path/atomicity.py +++ b/checkpoint/orbax/checkpoint/path/atomicity.py @@ -62,6 +62,7 @@ import jax from orbax.checkpoint import options as options_lib from orbax.checkpoint._src.metadata import checkpoint as checkpoint_metadata +from orbax.checkpoint._src.metadata import step_metadata_serialization from orbax.checkpoint._src.multihost import counters from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import utils @@ -88,9 +89,7 @@ def from_final( cls, final_path: epath.Path, *, - checkpoint_metadata_store: Optional[ - checkpoint_metadata.CheckpointMetadataStore - ] = None, + metadata_store: Optional[checkpoint_metadata.MetadataStore] = None, file_options: Optional[options_lib.FileOptions] = None, multiprocessing_options: Optional[ options_lib.MultiprocessingOptions @@ -137,9 +136,7 @@ async def _create_tmp_directory( *, primary_host: Optional[int] = 0, path_permission_mode: int = step_lib.WORLD_READABLE_MODE, - checkpoint_metadata_store: Optional[ - checkpoint_metadata.CheckpointMetadataStore - ] = None, + metadata_store: Optional[checkpoint_metadata.MetadataStore] = None, ) -> epath.Path: """Creates a non-deterministic tmp directory for saving for given `final_dir`. @@ -152,8 +149,8 @@ async def _create_tmp_directory( 0o750. Please check https://github.com/google/etils/blob/main/etils/epath/backend.py if your path is supported. - checkpoint_metadata_store: optional `CheckpointMetadataStore` instance. If - present then it is used to create `StepMetadata` with current timestamp. + metadata_store: optional `MetadataStore` instance. If present then it is + used to create `StepMetadata` with current timestamp. Returns: The tmp directory. @@ -182,12 +179,13 @@ async def _create_tmp_directory( exist_ok=False, mode=path_permission_mode, ) - if checkpoint_metadata_store is not None: - checkpoint_metadata_store.write( - checkpoint_path=tmp_dir, - checkpoint_metadata=checkpoint_metadata.StepMetadata( - init_timestamp_nsecs=time.time_ns() - ), + if metadata_store is not None: + metadata = checkpoint_metadata.StepMetadata( + init_timestamp_nsecs=time.time_ns(), + ) + metadata_store.write( + file_path=metadata.file_path(tmp_dir), + metadata=step_metadata_serialization.serialize(metadata), ) return tmp_dir @@ -217,9 +215,7 @@ def __init__( temporary_path: epath.Path, final_path: epath.Path, *, - checkpoint_metadata_store: Optional[ - checkpoint_metadata.CheckpointMetadataStore - ] = None, + metadata_store: Optional[checkpoint_metadata.MetadataStore] = None, file_options: Optional[options_lib.FileOptions] = None, multiprocessing_options: Optional[ options_lib.MultiprocessingOptions @@ -232,7 +228,7 @@ def __init__( multiprocessing_options or options_lib.MultiprocessingOptions() ) file_options = file_options or options_lib.FileOptions() - self._checkpoint_metadata_store = checkpoint_metadata_store + self._metadata_store = metadata_store self._primary_host = multiprocessing_options.primary_host self._active_processes = multiprocessing_options.active_processes self._barrier_sync_key_prefix = ( @@ -245,9 +241,7 @@ def from_final( cls, final_path: epath.Path, *, - checkpoint_metadata_store: Optional[ - checkpoint_metadata.CheckpointMetadataStore - ] = None, + metadata_store: Optional[checkpoint_metadata.MetadataStore] = None, file_options: Optional[options_lib.FileOptions] = None, multiprocessing_options: Optional[ options_lib.MultiprocessingOptions @@ -256,7 +250,7 @@ def from_final( return cls( _get_tmp_directory(final_path), final_path, - checkpoint_metadata_store=checkpoint_metadata_store, + metadata_store=metadata_store, file_options=file_options, multiprocessing_options=multiprocessing_options, ) @@ -306,7 +300,7 @@ async def create( self._tmp_path, primary_host=self._primary_host, path_permission_mode=mode, - checkpoint_metadata_store=self._checkpoint_metadata_store, + metadata_store=self._metadata_store, ) def finalize(self): @@ -315,13 +309,13 @@ def finalize(self): Updates checkpoint metadata with commit_timestamp_nsecs. """ logging.info('Renaming %s to %s', self._tmp_path, self._final_path) - if self._checkpoint_metadata_store: - self._checkpoint_metadata_store.wait_until_finished() - self._checkpoint_metadata_store.update( - checkpoint_path=self._tmp_path, + if self._metadata_store: + self._metadata_store.wait_until_finished() + self._metadata_store.update( + file_path=checkpoint_metadata.StepMetadata.file_path(self._tmp_path), commit_timestamp_nsecs=time.time_ns(), ) - self._checkpoint_metadata_store.wait_until_finished() + self._metadata_store.wait_until_finished() self._tmp_path.rename(self._final_path) def __repr__(self) -> str: @@ -340,9 +334,7 @@ def __init__( temporary_path: epath.Path, final_path: epath.Path, *, - checkpoint_metadata_store: Optional[ - checkpoint_metadata.CheckpointMetadataStore - ] = None, + metadata_store: Optional[checkpoint_metadata.MetadataStore] = None, file_options: Optional[options_lib.FileOptions] = None, multiprocessing_options: Optional[ options_lib.MultiprocessingOptions @@ -355,7 +347,7 @@ def __init__( multiprocessing_options or options_lib.MultiprocessingOptions() ) file_options = file_options or options_lib.FileOptions() - self._checkpoint_metadata_store = checkpoint_metadata_store + self._metadata_store = metadata_store self._primary_host = multiprocessing_options.primary_host self._active_processes = multiprocessing_options.active_processes self._barrier_sync_key_prefix = ( @@ -368,9 +360,7 @@ def from_final( cls, final_path: epath.Path, *, - checkpoint_metadata_store: Optional[ - checkpoint_metadata.CheckpointMetadataStore - ] = None, + metadata_store: Optional[checkpoint_metadata.MetadataStore] = None, file_options: Optional[options_lib.FileOptions] = None, multiprocessing_options: Optional[ options_lib.MultiprocessingOptions @@ -379,7 +369,7 @@ def from_final( return cls( final_path, final_path, - checkpoint_metadata_store=checkpoint_metadata_store, + metadata_store=metadata_store, file_options=file_options, multiprocessing_options=multiprocessing_options, ) @@ -427,7 +417,7 @@ async def create( self._tmp_path, primary_host=self._primary_host, path_permission_mode=mode, - checkpoint_metadata_store=self._checkpoint_metadata_store, + metadata_store=self._metadata_store, ) def finalize(self): @@ -436,13 +426,13 @@ def finalize(self): Updates checkpoint metadata with commit_timestamp_nsecs. """ logging.info('Finalizing %s', self._tmp_path) - if self._checkpoint_metadata_store: - self._checkpoint_metadata_store.wait_until_finished() - self._checkpoint_metadata_store.update( - checkpoint_path=self._tmp_path, + if self._metadata_store: + self._metadata_store.wait_until_finished() + self._metadata_store.update( + file_path=checkpoint_metadata.StepMetadata.file_path(self._tmp_path), commit_timestamp_nsecs=time.time_ns(), ) - self._checkpoint_metadata_store.wait_until_finished() + self._metadata_store.wait_until_finished() commit_success_file = self._final_path / step_lib._COMMIT_SUCCESS_FILE # pylint: disable=protected-access commit_success_file.write_text( f'Checkpoint commit was successful to {self._final_path}' diff --git a/checkpoint/orbax/checkpoint/path/format_utils.py b/checkpoint/orbax/checkpoint/path/format_utils.py index 96f125c4..0135fd8c 100644 --- a/checkpoint/orbax/checkpoint/path/format_utils.py +++ b/checkpoint/orbax/checkpoint/path/format_utils.py @@ -93,9 +93,7 @@ def is_orbax_checkpoint(path: epath.PathLike) -> bool: raise FileNotFoundError(f'Checkpoint path {path} does not exist.') if not path.is_dir(): raise NotADirectoryError(f'Checkpoint path {path} is not a directory.') - metadata_store = checkpoint_metadata.checkpoint_metadata_store( - enable_write=False - ) + metadata_store = checkpoint_metadata.metadata_store(enable_write=False) # Path points to a single step checkpoint with valid metadata. if metadata_store.read(path) is not None: return True diff --git a/checkpoint/orbax/checkpoint/path/step.py b/checkpoint/orbax/checkpoint/path/step.py index 384ff844..47a23493 100644 --- a/checkpoint/orbax/checkpoint/path/step.py +++ b/checkpoint/orbax/checkpoint/path/step.py @@ -28,6 +28,7 @@ import jax import numpy as np from orbax.checkpoint._src.metadata import checkpoint +from orbax.checkpoint._src.metadata import step_metadata_serialization from orbax.checkpoint._src.multihost import multihost @@ -61,8 +62,12 @@ class Metadata: @functools.cached_property def _checkpoint_metadata(self) -> Optional[checkpoint.StepMetadata]: """Returns checkpoint metadata of this step.""" - store = checkpoint.checkpoint_metadata_store(enable_write=False) - return store.read(self.path) + metadata_dict = checkpoint.metadata_store(enable_write=False).read( + file_path=checkpoint.StepMetadata.file_path(self.path) + ) + if metadata_dict is None: + return None + return step_metadata_serialization.deserialize(metadata_dict) @property def init_timestamp_nsecs(self) -> Optional[int]: diff --git a/checkpoint/orbax/checkpoint/path/step_test.py b/checkpoint/orbax/checkpoint/path/step_test.py index fc614da3..cebcdc8c 100644 --- a/checkpoint/orbax/checkpoint/path/step_test.py +++ b/checkpoint/orbax/checkpoint/path/step_test.py @@ -21,6 +21,7 @@ from etils import epath from orbax.checkpoint import test_utils from orbax.checkpoint._src.metadata import checkpoint +from orbax.checkpoint._src.metadata import step_metadata_serialization from orbax.checkpoint.path import atomicity from orbax.checkpoint.path import step as step_lib @@ -377,13 +378,13 @@ def setUp(self): def test_checkpoint_metadata_based_fields(self): step_path = self.directory / 'step_1' step_path.mkdir(parents=True, exist_ok=True) - checkpoint.checkpoint_metadata_store( - enable_write=True, blocking_write=True - ).write( - step_path, - checkpoint.StepMetadata( - init_timestamp_nsecs=1, commit_timestamp_nsecs=2 - ), + metadata = checkpoint.StepMetadata( + init_timestamp_nsecs=1, + commit_timestamp_nsecs=2, + ) + checkpoint.metadata_store(enable_write=True, blocking_write=True).write( + file_path=metadata.file_path(step_path), + metadata=step_metadata_serialization.serialize(metadata), ) metadata = step_lib.Metadata(step=1, path=step_path) diff --git a/checkpoint/orbax/checkpoint/standard_checkpointer.py b/checkpoint/orbax/checkpoint/standard_checkpointer.py index b3dce2ad..91e1b298 100644 --- a/checkpoint/orbax/checkpoint/standard_checkpointer.py +++ b/checkpoint/orbax/checkpoint/standard_checkpointer.py @@ -64,9 +64,7 @@ def __init__( async_options: options_lib.AsyncOptions = options_lib.AsyncOptions(), multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(), file_options: options_lib.FileOptions = options_lib.FileOptions(), - checkpoint_metadata_store: Optional[ - checkpoint.CheckpointMetadataStore - ] = None, + metadata_store: Optional[checkpoint.MetadataStore] = None, temporary_path_class: Optional[Type[atomicity.TemporaryPath]] = None, **kwargs, ): @@ -76,7 +74,7 @@ def __init__( async_options: See superclass documentation. multiprocessing_options: See superclass documentation. file_options: See superclass documentation. - checkpoint_metadata_store: See superclass documentation. + metadata_store: See superclass documentation. temporary_path_class: See superclass documentation. **kwargs: Additional init args passed to StandardCHeckpointHandler. See orbax.checkpoint.standard_checkpoint_handler.StandardCheckpointHandler. @@ -89,7 +87,7 @@ def __init__( async_options=async_options, multiprocessing_options=multiprocessing_options, file_options=file_options, - checkpoint_metadata_store=checkpoint_metadata_store, + metadata_store=metadata_store, temporary_path_class=temporary_path_class, ) diff --git a/checkpoint/orbax/checkpoint/test_utils.py b/checkpoint/orbax/checkpoint/test_utils.py index 5b6753ed..d239c43f 100644 --- a/checkpoint/orbax/checkpoint/test_utils.py +++ b/checkpoint/orbax/checkpoint/test_utils.py @@ -40,6 +40,7 @@ from orbax.checkpoint._src.handlers import async_checkpoint_handler from orbax.checkpoint._src.handlers import pytree_checkpoint_handler from orbax.checkpoint._src.metadata import checkpoint as checkpoint_metadata +from orbax.checkpoint._src.metadata import step_metadata_serialization from orbax.checkpoint._src.multihost import counters from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.multihost import multislice @@ -112,9 +113,7 @@ def create_tmp_directory( active_processes: Optional[Set[int]] = None, barrier_sync_key_prefix: Optional[str] = None, path_permission_mode: int = step_lib.WORLD_READABLE_MODE, - checkpoint_metadata_store: Optional[ - checkpoint_metadata.CheckpointMetadataStore - ] = None, + metadata_store: Optional[checkpoint_metadata.MetadataStore] = None, ) -> epath.Path: """Creates a non-deterministic tmp directory for saving for given `final_dir`. @@ -130,9 +129,8 @@ def create_tmp_directory( 0o750. Please check https://github.com/google/etils/blob/main/etils/epath/backend.py if your path is supported. - checkpoint_metadata_store: optional `CheckpointMetadataStore` instance. If - present then it is used to create `CheckpointMetadata` with current - timestamp. + metadata_store: optional `MetadataStore` instance. If present then it is + used to create `StepMetadata` with current timestamp. Returns: The tmp directory. @@ -167,12 +165,13 @@ def create_tmp_directory( ) logging.info('Creating tmp directory %s', tmp_dir) tmp_dir.mkdir(parents=True, exist_ok=False, mode=path_permission_mode) - if checkpoint_metadata_store is not None: - checkpoint_metadata_store.write( - checkpoint_path=tmp_dir, - checkpoint_metadata=checkpoint_metadata.StepMetadata( - init_timestamp_nsecs=time.time_ns() - ), + if metadata_store is not None: + metadata = checkpoint_metadata.StepMetadata( + init_timestamp_nsecs=time.time_ns(), + ) + metadata_store.write( + file_path=metadata.file_path(tmp_dir), + metadata=step_metadata_serialization.serialize(metadata), ) multihost.sync_global_processes( @@ -619,22 +618,20 @@ def test_foo(self): def ensure_atomic_save( temp_ckpt_dir: epath.Path, final_ckpt_dir: epath.Path, - checkpoint_metadata_store: Optional[ - checkpoint_metadata.CheckpointMetadataStore - ] = None, + metadata_store: Optional[checkpoint_metadata.MetadataStore] = None, ): """Wrapper around TemporaryPath.finalize for testing.""" if temp_ckpt_dir == final_ckpt_dir: atomicity.CommitFileTemporaryPath( temp_ckpt_dir, final_ckpt_dir, - checkpoint_metadata_store=checkpoint_metadata_store, + metadata_store=metadata_store, ).finalize() else: atomicity.AtomicRenameTemporaryPath( temp_ckpt_dir, final_ckpt_dir, - checkpoint_metadata_store=checkpoint_metadata_store, + metadata_store=metadata_store, ).finalize()