diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index 61dca320..f8ad8d00 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -607,11 +607,11 @@ def run(args: argparse.Namespace) -> None:
             group["lr"] = args.lr
 
     if args.lbfgs_config:
-        use_lbfgs = True
         max_iter = args.lbfgs_config.get("max_iter", 200)
         history_size = args.lbfgs_config.get("history", 240)
         batch_mode = args.lbfgs_config.get("batch_mode", False)
+        logging.info("Switching optimizer to LBFGS")
         optimizer = LBFGSNew(model.parameters(),
                              tolerance_grad=1e-6,
                              history_size=history_size,
@@ -643,7 +643,6 @@ def run(args: argparse.Namespace) -> None:
         device=device,
         swa=swa,
         ema=ema,
-        lbfgs=use_lbfgs,
         max_grad_norm=args.clip_grad,
         log_errors=args.error_table,
         log_wandb=args.wandb,
diff --git a/mace/tools/train.py b/mace/tools/train.py
index 98661b61..a8e0dc3a 100644
--- a/mace/tools/train.py
+++ b/mace/tools/train.py
@@ -19,6 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch_ema import ExponentialMovingAverage
 from torchmetrics import Metric
 
+from .lbfgsnew import LBFGSNew
 from . import torch_geometric
 from .checkpoint import CheckpointHandler, CheckpointState
@@ -151,7 +152,6 @@ def train(
     device: torch.device,
     log_errors: str,
     swa: Optional[SWAContainer] = None,
-    lbfgs: bool = False,
     ema: Optional[ExponentialMovingAverage] = None,
     max_grad_norm: Optional[float] = 10.0,
     log_wandb: bool = False,
@@ -228,7 +228,6 @@ def train(
             device=device,
             distributed_model=distributed_model,
             rank=rank,
-            use_lbfgs=lbfgs
         )
         if distributed:
             torch.distributed.barrier()
@@ -333,13 +332,11 @@ def train_one_epoch(
     device: torch.device,
     distributed_model: Optional[DistributedDataParallel] = None,
     rank: Optional[int] = 0,
-    use_lbfgs: bool = False,
 ) -> None:
     model_to_train = model if distributed_model is None else distributed_model
-    take_step_fn = take_lbfgs_step if use_lbfgs else take_step
     for batch in data_loader:
-        _, opt_metrics = take_step_fn(
+        _, opt_metrics = take_step(
             model=model_to_train,
             loss_fn=loss_fn,
             batch=batch,
@@ -364,44 +361,6 @@ def take_step(
     output_args: Dict[str, bool],
     max_grad_norm: Optional[float],
     device: torch.device,
-) -> Tuple[float, Dict[str, Any]]:
-    start_time = time.time()
-    batch = batch.to(device)
-    optimizer.zero_grad(set_to_none=True)
-    batch_dict = batch.to_dict()
-    output = model(
-        batch_dict,
-        training=True,
-        compute_force=output_args["forces"],
-        compute_virials=output_args["virials"],
-        compute_stress=output_args["stress"],
-    )
-    loss = loss_fn(pred=output, ref=batch)
-    loss.backward()
-    if max_grad_norm is not None:
-        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
-    optimizer.step()
-
-    if ema is not None:
-        ema.update()
-
-    loss_dict = {
-        "loss": to_numpy(loss),
-        "time": time.time() - start_time,
-    }
-
-    return loss, loss_dict
-
-
-def take_lbfgs_step(
-    model: torch.nn.Module,
-    loss_fn: torch.nn.Module,
-    batch: torch_geometric.batch.Batch,
-    optimizer: torch.optim.Optimizer,
-    ema: Optional[ExponentialMovingAverage],
-    output_args: Dict[str, bool],
-    max_grad_norm: Optional[float],
-    device: torch.device,
 ) -> Tuple[float, Dict[str, Any]]:
     start_time = time.time()
     batch = batch.to(device)
@@ -418,11 +377,17 @@ def closure():
         )
         loss = loss_fn(pred=output, ref=batch)
         loss.backward()
+        if max_grad_norm is not None:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
         return loss
 
-    optimizer.step(closure)
-    loss = closure()
+    if isinstance(optimizer, LBFGSNew):
+        optimizer.step(closure)
+        loss = closure()
+    else:
+        loss = closure()
+        optimizer.step()
 
     if ema is not None:
         ema.update()
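
Note (reviewer sketch, not part of the patch): with take_lbfgs_step removed, the single take_step dispatches on the optimizer type at runtime. Closure-based optimizers such as LBFGSNew drive the forward/backward pass themselves through optimizer.step(closure), while first-order optimizers keep the usual single backward pass followed by optimizer.step(). The snippet below is a minimal, self-contained illustration of that dispatch pattern; it uses torch.optim.LBFGS as a stand-in for LBFGSNew, and the model, data, and max_grad_norm value are placeholders.

import torch

model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss()
x, y = torch.randn(32, 4), torch.randn(32, 1)

def make_closure(optimizer, max_grad_norm=10.0):
    # The closure re-evaluates the loss; L-BFGS may call it several times per step.
    def closure():
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(x), y)
        loss.backward()
        # Clipping inside the closure mirrors the patched take_step: every
        # re-evaluation performed by the optimizer sees clipped gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        return loss
    return closure

for optimizer in (
    torch.optim.LBFGS(model.parameters(), history_size=10, max_iter=20),
    torch.optim.Adam(model.parameters(), lr=1e-3),
):
    closure = make_closure(optimizer)
    if isinstance(optimizer, torch.optim.LBFGS):
        optimizer.step(closure)  # closure-based path: the optimizer drives the closure
        loss = closure()         # extra evaluation only to report the final loss
    else:
        loss = closure()         # standard path: one forward/backward pass
        optimizer.step()
    print(type(optimizer).__name__, float(loss))

One consequence of the unification is that gradient clipping now runs inside the closure on both paths, which reproduces the clip-then-step behaviour of the deleted take_step on the standard path.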