Commit

cleanup
ttompa committed Dec 18, 2024
1 parent: fd57dae · commit: 5002ff3
Showing 2 changed files with 11 additions and 47 deletions.
3 changes: 1 addition & 2 deletions mace/cli/run_train.py
@@ -607,11 +607,11 @@ def run(args: argparse.Namespace) -> None:
group["lr"] = args.lr

if args.lbfgs_config:
use_lbfgs = True
max_iter = args.lbfgs_config.get("max_iter", 200)
history_size = args.lbfgs_config.get("history", 240)
batch_mode = args.lbfgs_config.get("batch_mode", False)

logging.info("Switching optimizer to LBFGS")
optimizer = LBFGSNew(model.parameters(),
tolerance_grad=1e-6,
history_size=history_size,
@@ -643,7 +643,6 @@ def run(args: argparse.Namespace) -> None:
device=device,
swa=swa,
ema=ema,
lbfgs=use_lbfgs,
max_grad_norm=args.clip_grad,
log_errors=args.error_table,
log_wandb=args.wandb,
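Note: the first hunk above reads the optional lbfgs_config mapping and builds the LBFGS optimizer from it. Below is a minimal, self-contained sketch of that pattern, not part of the commit: torch.optim.LBFGS stands in for the repository's LBFGSNew (whose full constructor is not shown in this diff), the placeholder model and lbfgs_config dict are illustrative, and batch_mode is assumed to be an LBFGSNew-specific option, so it is only read here, not passed on.

# Sketch (not part of the commit): mapping an lbfgs_config dict onto an LBFGS
# optimizer, mirroring the defaults in the hunk above (max_iter=200, history=240).
# torch.optim.LBFGS stands in for mace.tools.lbfgsnew.LBFGSNew; batch_mode is
# assumed to be LBFGSNew-specific and is only read, not forwarded, here.
import logging

import torch

model = torch.nn.Linear(4, 1)  # placeholder for the MACE model
lbfgs_config = {"max_iter": 100, "history": 120}  # e.g. parsed from CLI args

if lbfgs_config:
    max_iter = lbfgs_config.get("max_iter", 200)
    history_size = lbfgs_config.get("history", 240)
    batch_mode = lbfgs_config.get("batch_mode", False)  # unused with torch's LBFGS
    logging.info("Switching optimizer to LBFGS")
    optimizer = torch.optim.LBFGS(
        model.parameters(),
        tolerance_grad=1e-6,
        history_size=history_size,
        max_iter=max_iter,
    )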
55 changes: 10 additions & 45 deletions mace/tools/train.py
@@ -19,6 +19,7 @@
from torch.utils.data.distributed import DistributedSampler
from torch_ema import ExponentialMovingAverage
from torchmetrics import Metric
from .lbfgsnew import LBFGSNew

from . import torch_geometric
from .checkpoint import CheckpointHandler, CheckpointState
@@ -151,7 +152,6 @@ def train(
device: torch.device,
log_errors: str,
swa: Optional[SWAContainer] = None,
lbfgs: bool = False,
ema: Optional[ExponentialMovingAverage] = None,
max_grad_norm: Optional[float] = 10.0,
log_wandb: bool = False,
@@ -228,7 +228,6 @@ def train(
device=device,
distributed_model=distributed_model,
rank=rank,
use_lbfgs=lbfgs
)
if distributed:
torch.distributed.barrier()
@@ -333,13 +332,11 @@ def train_one_epoch(
device: torch.device,
distributed_model: Optional[DistributedDataParallel] = None,
rank: Optional[int] = 0,
use_lbfgs: bool = False,
) -> None:
model_to_train = model if distributed_model is None else distributed_model
take_step_fn = take_lbfgs_step if use_lbfgs else take_step

for batch in data_loader:
_, opt_metrics = take_step_fn(
_, opt_metrics = take_step(
model=model_to_train,
loss_fn=loss_fn,
batch=batch,
@@ -364,44 +361,6 @@ def take_step(
output_args: Dict[str, bool],
max_grad_norm: Optional[float],
device: torch.device,
) -> Tuple[float, Dict[str, Any]]:
start_time = time.time()
batch = batch.to(device)
optimizer.zero_grad(set_to_none=True)
batch_dict = batch.to_dict()
output = model(
batch_dict,
training=True,
compute_force=output_args["forces"],
compute_virials=output_args["virials"],
compute_stress=output_args["stress"],
)
loss = loss_fn(pred=output, ref=batch)
loss.backward()
if max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
optimizer.step()

if ema is not None:
ema.update()

loss_dict = {
"loss": to_numpy(loss),
"time": time.time() - start_time,
}

return loss, loss_dict


def take_lbfgs_step(
model: torch.nn.Module,
loss_fn: torch.nn.Module,
batch: torch_geometric.batch.Batch,
optimizer: torch.optim.Optimizer,
ema: Optional[ExponentialMovingAverage],
output_args: Dict[str, bool],
max_grad_norm: Optional[float],
device: torch.device,
) -> Tuple[float, Dict[str, Any]]:
start_time = time.time()
batch = batch.to(device)
@@ -418,11 +377,17 @@ def closure():
)
loss = loss_fn(pred=output, ref=batch)
loss.backward()
if max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

return loss

optimizer.step(closure)
loss = closure()
if isinstance(optimizer, LBFGSNew):
optimizer.step(closure)
loss = closure()
else:
loss = closure()
optimizer.step()

if ema is not None:
ema.update()
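Note: after this cleanup, the choice between a standard step and a closure-based LBFGS step is made by inspecting the optimizer type inside take_step, instead of threading a use_lbfgs flag through train, train_one_epoch, and a separate take_lbfgs_step. A minimal sketch of that dispatch follows; it is not MACE code: torch.optim.LBFGS stands in for LBFGSNew, and a toy model, data, and loss replace the batch pipeline.

# Sketch (not part of the commit): one take_step that dispatches on the
# optimizer type. LBFGS-style optimizers re-evaluate the closure internally,
# so they receive it in step(); first-order optimizers run it once and then
# step. Toy data and model stand in for the MACE batch/loss machinery.
from typing import Optional

import torch

model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()

def take_step(optimizer: torch.optim.Optimizer,
              max_grad_norm: Optional[float] = 10.0) -> torch.Tensor:
    def closure():
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(x), y)
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        return loss

    if isinstance(optimizer, torch.optim.LBFGS):
        optimizer.step(closure)   # LBFGS calls the closure as needed internally
        loss = closure()          # evaluate once more to report the final loss
    else:
        loss = closure()
        optimizer.step()
    return loss

print(take_step(torch.optim.SGD(model.parameters(), lr=1e-2)).item())
print(take_step(torch.optim.LBFGS(model.parameters(), history_size=10)).item())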
