Add LBFGS optimizer as an option #792

Merged
merged 55 commits into main from lbfgs-multi-gpu on Mar 28, 2025

55 commits
83d15e3
using torch.optim.LBFGS
Dec 4, 2024
7b06eca
add LBFGS step using lbfgsnew.py
Dec 13, 2024
2677ae2
add lbfgs_config argument
ttompa Dec 14, 2024
337a607
add lbfgs as an option for the optimizer argument
ttompa Dec 17, 2024
5945ba1
remove --lbfgs argument
ttompa Dec 17, 2024
5d9ff24
cleanup
ttompa Dec 17, 2024
fd57dae
move lbfgs out of optimizer args, because checkpoint loading conflict…
ttompa Dec 18, 2024
5002ff3
cleanup
ttompa Dec 18, 2024
93b3e41
add CUDA memory logs for debugging
ttompa Dec 18, 2024
34d7e26
test - enable keep_graph for foundation model
ttompa Dec 18, 2024
5eae9d7
disable retain_graph
Dec 19, 2024
9cb5d65
train only readouts with lbfgs
ttompa Dec 19, 2024
76b333e
make LBFGS always use full batch, but calculate gradients from mini b…
ttompa Dec 21, 2024
246655c
fixes for the lbfgs step
ttompa Dec 22, 2024
c684f42
use pytorch's LBFGS optimizer
ttompa Dec 22, 2024
af19ccc
add shampoo optimizer
ttompa Dec 23, 2024
be846c0
initial multi GPU LBFGS test
ttompa Dec 23, 2024
eba347a
small fix
ttompa Dec 24, 2024
2d422a3
cleanup
ttompa Dec 28, 2024
321996d
add more barriers and try to prevent LBFGS steps from running on extr…
ttompa Dec 29, 2024
996b0b0
Revert "add more barriers and try to prevent LBFGS steps from running…
ttompa Dec 29, 2024
da65079
add extra barriers to keep GPUs in sync
ttompa Dec 29, 2024
e134289
Revert "add extra barriers to keep GPUs in sync"
ttompa Dec 29, 2024
3de91bf
add logging for debugging multi GPU LBFGS
ttompa Dec 29, 2024
098d0c8
test a different shampoo implementation
ttompa Dec 29, 2024
3cd1cf5
log on all GPUs
ttompa Dec 29, 2024
84aecdd
print batch size
ttompa Dec 29, 2024
4e1ea2f
Revert "test a different shampoo implementation"
ttompa Dec 29, 2024
5fa3a34
tune shampoo params
ttompa Dec 29, 2024
71b86b2
add waiting signal for multi GPU LBFGS
ttompa Dec 29, 2024
4393f6f
Merge branch 'lbfgs-multi-gpu' into shampoo-optimizer
vue1999 Dec 29, 2024
81020d5
Merge pull request #2 from vue1999/shampoo-optimizer
vue1999 Dec 29, 2024
3774cbb
Merge branch 'ACEsuit:main' into lbfgs-multi-gpu
vue1999 Dec 29, 2024
be70e51
workaround for shampoo's lack of checkpoint handling for testing
ttompa Dec 29, 2024
bc02bcb
add wolfe line search, remove debug logs, remove redundant barriers
ttompa Dec 30, 2024
99c8c8c
adjust normalisation to handle varying batch sizes
ttompa Dec 30, 2024
482a864
fix single GPU LBFGS training
ttompa Dec 30, 2024
45fee81
remove redundant barrier
ttompa Dec 30, 2024
2f2be1d
fix spelling
ttompa Dec 30, 2024
f252276
remove debug logging
ttompa Dec 30, 2024
9923ba1
adjust shampoo params for testing
ttompa Dec 30, 2024
4532b12
adjust shampoo params for testing
ttompa Dec 30, 2024
1bf4bb7
Merge branch 'ACEsuit:main' into lbfgs-multi-gpu
ttompa Jan 12, 2025
63a3c7c
remove shampoo related code
ttompa Jan 12, 2025
4b9ee30
cleanup
ttompa Jan 12, 2025
8b0b396
simplify LBFGS option
ttompa Jan 12, 2025
8e26459
cleanup
ttompa Jan 12, 2025
0857727
fix restarting LBFGS checkpoints
ttompa Jan 12, 2025
aa392a6
reduce iteration count for spice test, always save last checkpoint
ttompa Jan 15, 2025
f2243bc
disable drop_last when using LBFGS, adjust normalisation to account f…
ttompa Jan 20, 2025
5da7dff
Fix issue by initialising opt_start_epoch as None
vue1999 Jan 21, 2025
1b138ce
fix formatting to make CI checks happy
Jan 21, 2025
24c9bec
Update run_train.py
vue1999 Feb 25, 2025
e55411a
add unit test for LBFGS
ttompa Mar 14, 2025
2acf7df
Merge branch 'main' into lbfgs-multi-gpu
ttompa Mar 14, 2025
37 changes: 29 additions & 8 deletions mace/cli/run_train.py
@@ -17,6 +17,7 @@
import torch.nn.functional
from e3nn.util import jit
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import LBFGS
from torch.utils.data import ConcatDataset
from torch_ema import ExponentialMovingAverage

@@ -639,7 +640,7 @@ def run(args) -> None:
dataset=train_sets[head_config.head_name],
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
drop_last=(not args.lbfgs),
pin_memory=args.pin_memory,
num_workers=args.num_workers,
generator=torch.Generator().manual_seed(args.seed),
@@ -655,7 +656,7 @@ def run(args) -> None:
num_replicas=world_size,
rank=rank,
shuffle=True,
drop_last=True,
drop_last=(not args.lbfgs),
seed=args.seed,
)
valid_samplers = {}
@@ -674,7 +675,7 @@ def run(args) -> None:
batch_size=args.batch_size,
sampler=train_sampler,
shuffle=(train_sampler is None),
drop_last=(train_sampler is None),
drop_last=(train_sampler is None and not args.lbfgs),
pin_memory=args.pin_memory,
num_workers=args.num_workers,
generator=torch.Generator().manual_seed(args.seed),
@@ -746,6 +747,8 @@ def run(args) -> None:
)

start_epoch = 0
restart_lbfgs = False
opt_start_epoch = None
if args.restart_latest:
try:
opt_start_epoch = checkpoint_handler.load_latest(
@@ -754,11 +757,14 @@ def run(args) -> None:
device=device,
)
except Exception: # pylint: disable=W0703
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False,
device=device,
)
try:
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False,
device=device,
)
except Exception: # pylint: disable=W0703
restart_lbfgs = True
if opt_start_epoch is not None:
start_epoch = opt_start_epoch

@@ -769,6 +775,21 @@ def run(args) -> None:
for group in optimizer.param_groups:
group["lr"] = args.lr

if args.lbfgs:
logging.info("Switching optimizer to LBFGS")
optimizer = LBFGS(model.parameters(),
history_size=200,
max_iter=20,
line_search_fn="strong_wolfe")
if restart_lbfgs:
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False,
device=device,
)
if opt_start_epoch is not None:
start_epoch = opt_start_epoch

if args.wandb:
setup_wandb(args)
if args.distributed:
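For orientation, here is a minimal standalone sketch of the optimizer swap that `run_train.py` performs when `--lbfgs` is set. The hyperparameters mirror the values in the diff above; the toy linear model only stands in for the MACE model so the snippet runs on its own.

```python
import torch
from torch.optim import LBFGS

# Toy stand-in for the MACE model, only so the sketch is runnable on its own.
model = torch.nn.Linear(4, 1)

# Mirrors the construction in run_train.py when args.lbfgs is True.
optimizer = LBFGS(
    model.parameters(),
    history_size=200,               # curvature pairs kept for the Hessian approximation
    max_iter=20,                    # inner L-BFGS iterations per .step(closure) call
    line_search_fn="strong_wolfe",  # line search instead of a fixed learning rate
)
```

Note that when `restart_lbfgs` is set, the latest checkpoint is loaded only after this swap, presumably because its saved optimizer state matches LBFGS rather than the optimizer originally built from the optimizer args.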
6 changes: 6 additions & 0 deletions mace/tools/arg_parser.py
@@ -658,6 +658,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
default=None,
dest="start_swa",
)
parser.add_argument(
"--lbfgs",
help="Switch to L-BFGS optimizer",
action="store_true",
default=False,
)
parser.add_argument(
"--ema",
help="use Exponential Moving Average",
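A minimal sketch of what the new flag does at the CLI level, using a reduced parser rather than the full `build_default_arg_parser()` (which defines many more options):

```python
import argparse

# Reduced parser containing only the flag added in this PR.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--lbfgs",
    help="Switch to L-BFGS optimizer",
    action="store_true",
    default=False,
)

print(parser.parse_args([]).lbfgs)           # False: default optimizer, drop_last stays enabled
print(parser.parse_args(["--lbfgs"]).lbfgs)  # True: LBFGS replaces the optimizer, drop_last is disabled
```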
156 changes: 139 additions & 17 deletions mace/tools/train.py
@@ -14,6 +14,7 @@
import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel
from torch.optim import LBFGS
from torch.optim.swa_utils import SWALR, AveragedModel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
@@ -182,7 +183,6 @@ def train(
epoch = start_epoch

# log validation loss before _any_ training
valid_loss = 0.0
for valid_loader_name, valid_loader in valid_loaders.items():
valid_loss_head, eval_metrics = evaluate(
model=model,
@@ -230,6 +230,7 @@
ema=ema,
logger=logger,
device=device,
distributed=distributed,
distributed_model=distributed_model,
rank=rank,
)
@@ -247,7 +248,6 @@
if "ScheduleFree" in type(optimizer).__name__:
optimizer.eval()
with param_context:
valid_loss = 0.0
wandb_log_dict = {}
for valid_loader_name, valid_loader in valid_loaders.items():
valid_loss_head, eval_metrics = evaluate(
@@ -342,25 +342,45 @@ def train_one_epoch(
ema: Optional[ExponentialMovingAverage],
logger: MetricsLogger,
device: torch.device,
distributed: bool,
distributed_model: Optional[DistributedDataParallel] = None,
rank: Optional[int] = 0,
) -> None:
model_to_train = model if distributed_model is None else distributed_model
for batch in data_loader:
_, opt_metrics = take_step(

if isinstance(optimizer, LBFGS):
_, opt_metrics = take_step_lbfgs(
model=model_to_train,
loss_fn=loss_fn,
batch=batch,
data_loader=data_loader,
optimizer=optimizer,
ema=ema,
output_args=output_args,
max_grad_norm=max_grad_norm,
device=device,
distributed=distributed,
rank=rank,
)
opt_metrics["mode"] = "opt"
opt_metrics["epoch"] = epoch
if rank == 0:
logger.log(opt_metrics)
else:
for batch in data_loader:
_, opt_metrics = take_step(
model=model_to_train,
loss_fn=loss_fn,
batch=batch,
optimizer=optimizer,
ema=ema,
output_args=output_args,
max_grad_norm=max_grad_norm,
device=device,
)
opt_metrics["mode"] = "opt"
opt_metrics["epoch"] = epoch
if rank == 0:
logger.log(opt_metrics)


def take_step(
@@ -375,19 +395,25 @@ def take_step(
) -> Tuple[float, Dict[str, Any]]:
start_time = time.time()
batch = batch.to(device)
optimizer.zero_grad(set_to_none=True)
batch_dict = batch.to_dict()
output = model(
batch_dict,
training=True,
compute_force=output_args["forces"],
compute_virials=output_args["virials"],
compute_stress=output_args["stress"],
)
loss = loss_fn(pred=output, ref=batch)
loss.backward()
if max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

def closure():
optimizer.zero_grad(set_to_none=True)
output = model(
batch_dict,
training=True,
compute_force=output_args["forces"],
compute_virials=output_args["virials"],
compute_stress=output_args["stress"],
)
loss = loss_fn(pred=output, ref=batch)
loss.backward()
if max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

return loss

loss = closure()
optimizer.step()

if ema is not None:
@@ -401,6 +427,102 @@ def take_step(
return loss, loss_dict


def take_step_lbfgs(
model: torch.nn.Module,
loss_fn: torch.nn.Module,
data_loader: DataLoader,
optimizer: torch.optim.Optimizer,
ema: Optional[ExponentialMovingAverage],
output_args: Dict[str, bool],
max_grad_norm: Optional[float],
device: torch.device,
distributed: bool,
rank: int,
) -> Tuple[float, Dict[str, Any]]:
start_time = time.time()
logging.debug(
f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB"
)

total_sample_count = 0
for batch in data_loader:
total_sample_count += batch.num_graphs

if distributed:
global_sample_count = torch.tensor(total_sample_count, device=device)
torch.distributed.all_reduce(
global_sample_count, op=torch.distributed.ReduceOp.SUM
)
total_sample_count = global_sample_count.item()

signal = torch.zeros(1, device=device) if distributed else None

def closure():
if distributed:
if rank == 0:
signal.fill_(1)
torch.distributed.broadcast(signal, src=0)

for param in model.parameters():
torch.distributed.broadcast(param.data, src=0)

optimizer.zero_grad(set_to_none=True)
total_loss = torch.tensor(0.0, device=device)

# Process each batch and then collect the results we pass to the optimizer
for batch in data_loader:
batch = batch.to(device)
batch_dict = batch.to_dict()
output = model(
batch_dict,
training=True,
compute_force=output_args["forces"],
compute_virials=output_args["virials"],
compute_stress=output_args["stress"],
)
batch_loss = loss_fn(pred=output, ref=batch)
batch_loss = batch_loss * (batch.num_graphs / total_sample_count)

batch_loss.backward()
total_loss += batch_loss

if max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

if distributed:
torch.distributed.all_reduce(total_loss, op=torch.distributed.ReduceOp.SUM)
return total_loss

if distributed:
if rank == 0:
loss = optimizer.step(closure)
signal.fill_(0)
torch.distributed.broadcast(signal, src=0)
else:
while True:
# Other ranks wait for signals from rank 0
torch.distributed.broadcast(signal, src=0)
if signal.item() == 0:
break
if signal.item() == 1:
loss = closure()

for param in model.parameters():
torch.distributed.broadcast(param.data, src=0)
else:
loss = optimizer.step(closure)

if ema is not None:
ema.update()

loss_dict = {
"loss": to_numpy(loss),
"time": time.time() - start_time,
}

return loss, loss_dict


def evaluate(
model: torch.nn.Module,
loss_fn: torch.nn.Module,
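To make the control flow of `take_step_lbfgs` easier to follow, below is a condensed sketch of its multi-GPU signalling: rank 0 drives `optimizer.step(closure)`, and each closure evaluation broadcasts a "keep evaluating" signal plus the current parameters so the other ranks repeat the same full-dataset forward/backward pass whenever the strong-Wolfe line search re-evaluates the loss. The MACE-specific model call and loss normalisation are collapsed into a placeholder `full_batch_loss` callable, and the placement of the broadcasts is a reconstruction of the diff above rather than a verbatim copy.

```python
import torch
import torch.distributed as dist


def lbfgs_step(model, optimizer, data_loader, device, distributed, rank,
               full_batch_loss):
    """Skeleton of the rank-0-driven step in take_step_lbfgs.

    full_batch_loss(model, data_loader) is a placeholder for the loop that
    accumulates gradients over every mini-batch and returns the (all-reduced,
    sample-count-normalised) loss over the whole training set. The distributed
    branch assumes torch.distributed.init_process_group() has been called.
    """
    signal = torch.zeros(1, device=device) if distributed else None

    def closure():
        if distributed:
            if rank == 0:
                signal.fill_(1)                # "evaluate the loss once more"
                dist.broadcast(signal, src=0)  # pairs with the wait loop below
            for p in model.parameters():       # keep replicas at rank 0's parameters
                dist.broadcast(p.data, src=0)
        optimizer.zero_grad(set_to_none=True)
        return full_batch_loss(model, data_loader)

    if not distributed:
        return optimizer.step(closure)         # LBFGS may call closure repeatedly

    if rank == 0:
        loss = optimizer.step(closure)         # line search drives the closure calls
        signal.fill_(0)
        dist.broadcast(signal, src=0)          # release the waiting ranks
    else:
        loss = None
        while True:
            dist.broadcast(signal, src=0)      # wait for rank 0's instruction
            if signal.item() == 0:
                break
            loss = closure()                   # mirror rank 0's closure evaluation
    for p in model.parameters():               # final sync of the updated parameters
        dist.broadcast(p.data, src=0)
    return loss


if __name__ == "__main__":
    # Single-process smoke test; mini-batches are plain (x, y) tuples here.
    model = torch.nn.Linear(3, 1)
    batches = [(torch.randn(8, 3), torch.randn(8, 1)) for _ in range(4)]
    optimizer = torch.optim.LBFGS(model.parameters(), history_size=200,
                                  max_iter=20, line_search_fn="strong_wolfe")

    def full_batch_loss(m, loader):
        n = sum(x.shape[0] for x, _ in loader)
        total = torch.tensor(0.0)
        for x, y in loader:
            loss = torch.nn.functional.mse_loss(m(x), y, reduction="sum") / n
            loss.backward()
            total += loss.detach()
        return total

    print(lbfgs_step(model, optimizer, batches, torch.device("cpu"),
                     distributed=False, rank=0, full_batch_loss=full_batch_loss))
```

The single-process path (`distributed=False`) reduces to a plain closure-based LBFGS step, which is what the smoke test at the bottom exercises.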