Skip to content

Commit

Permalink
update params loader
Browse files Browse the repository at this point in the history
  • Loading branch information
thorben-frank committed Aug 22, 2024
1 parent 271541f commit 39abb58
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion mlff/io/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,29 @@
import os

from orbax.checkpoint import PyTreeCheckpointer, Checkpointer, PyTreeCheckpointHandler
from orbax import checkpoint
import pathlib

__STEP_PREFIX__: str = 'ckpt'


def load_params_from_ckpt_dir(ckpt_dir):
loaded_mngr = checkpoint.CheckpointManager(
pathlib.Path(ckpt_dir).resolve(),
item_names=('state',),
item_handlers={'state': checkpoint.StandardCheckpointHandler()},
options=checkpoint.CheckpointManagerOptions(step_prefix="ckpt"),
)

mngr_state = loaded_mngr.restore(
loaded_mngr.latest_step()
)

state = mngr_state.get('state')

return state['valid_params']


def load_state_from_ckpt_dir(ckpt_dir: str):
# mngr = CheckpointManager(ckpt_dir, __CHECKPOINTERS__, options=CheckpointManagerOptions(step_prefix=__STEP_PREFIX__))
# return mngr.restore(n)['state']
Expand All @@ -27,5 +46,5 @@ def load_state_from_ckpt_dir(ckpt_dir: str):
return ckptr.restore(abs_ckpt_dir / f'{__STEP_PREFIX__}_{max_step}/state', item=None)


def load_params_from_ckpt_dir(ckpt_dir: str):
def _load_params_from_ckpt_dir(ckpt_dir: str):
return load_state_from_ckpt_dir(ckpt_dir)['valid_params']

0 comments on commit 39abb58

Please sign in to comment.