Skip to content

Commit

Permalink
fix restarting LBFGS checkpoints
Browse files (browse the repository at this point in the history)
  • Loading branch information
ttompa committed Jan 12, 2025
1 parent 8e26459 commit 0857727
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def run(args: argparse.Namespace) -> None:
)

start_epoch = 0
restart_lbfgs = False
if args.restart_latest:
try:
opt_start_epoch = checkpoint_handler.load_latest(
Expand All @@ -614,11 +615,14 @@ def run(args: argparse.Namespace) -> None:
device=device,
)
except Exception: # pylint: disable=W0703
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False,
device=device,
)
try:
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False,
device=device,
)
except Exception: # pylint: disable=W0703
restart_lbfgs = True
if opt_start_epoch is not None:
start_epoch = opt_start_epoch

Expand All @@ -635,6 +639,14 @@ def run(args: argparse.Namespace) -> None:
history_size=200,
max_iter=200,
line_search_fn="strong_wolfe")
if restart_lbfgs:
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False,
device=device,
)
if opt_start_epoch is not None:
start_epoch = opt_start_epoch

if args.wandb:
setup_wandb(args)
Expand Down

0 comments on commit 0857727

Please sign in to comment.