
Commit

Merge branch 'main' of github.com:ACEsuit/mace-jax
mariogeiger committed Feb 14, 2023
2 parents d37d428 + 96d6265 commit a224c12
Showing 9 changed files with 186 additions and 183 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -14,6 +14,14 @@ nox

## Installation

From GitHub:

```sh
pip install git+https://github.com/ACEsuit/mace-jax
```

Or from a local clone:

```sh
python setup.py develop
```
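
A quick way to confirm the install worked — a minimal sketch that only assumes the package is importable under the name `mace_jax`:

```python
# Minimal post-install check: import the package and print where it was installed.
# Assumes only that the pip/setup step above put `mace_jax` on the Python path.
import mace_jax

print(mace_jax.__file__)
```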
12 changes: 5 additions & 7 deletions configs/aspirin.gin
@@ -20,9 +20,7 @@ datasets.n_mantissa_bits = 3
model.radial_basis = @bessel_basis
bessel_basis.number = 8

# model.radial_envelope = @u_envelope # 10. @ epoch 350
# u_envelope.p = 5
model.radial_envelope = @soft_envelope # 9.5 @ epoch 350
model.radial_envelope = @soft_envelope

model.symmetric_tensor_product_basis = True # symmetric is slightly worse but slightly faster
model.off_diagonal = False
@@ -52,15 +50,15 @@ loss.stress_weight = 0.0

optimizer.algorithm = @amsgrad # @adam is slightly worse
optimizer.lr = 0.01
optimizer.max_num_epochs = 100
optimizer.steps_per_interval = 1024
optimizer.max_num_intervals = 100
# optimizer.weight_decay = 5e-7
# optimizer.scheduler = @piecewise_constant_schedule
# piecewise_constant_schedule.boundaries_and_scales = {
# 100: 0.1, # divide the learning rate by 10 after 100 epochs
# 1000: 0.1, # divide the learning rate by 10 after 1000 epochs
# 100: 0.1, # divide the learning rate by 10 after 100 intervals
# 1000: 0.1, # divide the learning rate by 10 after 1000 intervals
# }

train.eval_interval = 5
train.patience = 2048
train.ema_decay = 0.99
train.eval_train = False # if True, evaluates the whole training set at each eval_interval
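
The change above replaces the epoch-based `optimizer.max_num_epochs` with `optimizer.steps_per_interval` and `optimizer.max_num_intervals`, and the commented-out scheduler boundaries are now counted in intervals. A rough back-of-the-envelope sketch of the implied training length, using the values from this config (the multiplication is an assumption about the semantics, not taken from the mace-jax code):

```python
# Hypothetical bookkeeping sketch; only the two config keys and their values come
# from the diff, the interpretation "total steps = steps * intervals" is assumed.
steps_per_interval = 1024
max_num_intervals = 100

total_optimizer_steps = steps_per_interval * max_num_intervals
print(total_optimizer_steps)  # 102400 gradient steps if every interval runs to completion
```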
14 changes: 6 additions & 8 deletions configs/aspirin_small.gin
@@ -21,9 +21,7 @@ datasets.n_mantissa_bits = 3
model.radial_basis = @bessel_basis
bessel_basis.number = 8

# model.radial_envelope = @u_envelope # 10. @ epoch 350
# u_envelope.p = 5
model.radial_envelope = @soft_envelope # 9.5 @ epoch 350
model.radial_envelope = @soft_envelope

model.symmetric_tensor_product_basis = True # symmetric is slightly worse but slightly faster
model.off_diagonal = False
@@ -53,17 +51,17 @@ loss.stress_weight = 0.0

optimizer.algorithm = @sgd
optimizer.lr = 0.01
optimizer.max_num_epochs = 2
optimizer.steps_per_interval = 100
optimizer.max_num_intervals = 2
# optimizer.weight_decay = 5e-7
# optimizer.scheduler = @piecewise_constant_schedule
# piecewise_constant_schedule.boundaries_and_scales = {
# 100: 0.1, # divide the learning rate by 10 after 100 epochs
# 1000: 0.1, # divide the learning rate by 10 after 1000 epochs
# 100: 0.1, # divide the learning rate by 10 after 100 intervals
# 1000: 0.1, # divide the learning rate by 10 after 1000 intervals
# }

train.eval_interval = 5
train.patience = 2048
train.ema_decay = 0.99
train.eval_train = False # if True, evaluates the whole training set at each eval_interval
train.eval_train = True # if True, evaluates the whole training set at each eval_interval
train.eval_test = False
train.log_errors = "PerAtomMAE"
29 changes: 12 additions & 17 deletions mace_jax/plot_train.py
@@ -44,17 +44,12 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--path", help="path to results file or directory", required=True
)
parser.add_argument(
"--min_epoch", help="minimum epoch", default=50, type=int, required=False
)
return parser.parse_args()


def plot(data: pd.DataFrame, min_epoch: int, output_path: str) -> None:
data = data[data["epoch"] > min_epoch]

def plot(data: pd.DataFrame, output_path: str) -> None:
data = (
data.groupby(["path", "name", "mode", "epoch"])
data.groupby(["path", "name", "mode", "interval"])
.agg([np.mean, np.std])
.reset_index()
)
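
For readers unfamiliar with the multi-level columns this aggregation produces, here is a toy pandas example of the same `groupby(...).agg([np.mean, np.std])` pattern (hypothetical data, for illustration only):

```python
import numpy as np
import pandas as pd

# Toy frame mimicking the logged results; the values are made up.
df = pd.DataFrame({
    "path": ["runs"] * 4,
    "name": ["aspirin"] * 4,
    "mode": ["eval"] * 4,
    "interval": [1, 1, 2, 2],
    "loss": [0.50, 0.70, 0.30, 0.40],
})

agg = (
    df.groupby(["path", "name", "mode", "interval"])
    .agg([np.mean, np.std])
    .reset_index()
)

# .agg([np.mean, np.std]) creates a second column level, so values are read as
# agg["loss"]["mean"] / agg["loss"]["std"] -- the access pattern used in plot() below.
print(agg["loss"]["mean"].tolist())  # [0.6, 0.35]
```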
@@ -68,29 +63,29 @@ def plot(data: pd.DataFrame, min_epoch: int, output_path: str) -> None:

ax = axes[0]
ax.plot(
valid_data["epoch"],
valid_data["interval"],
valid_data["loss"]["mean"],
color=colors[0],
zorder=1,
label="Validation",
)
# ax.fill_between(
# x=valid_data["epoch"],
# x=valid_data["interval"],
# y1=valid_data["loss"]["mean"] - valid_data["loss"]["std"],
# y2=valid_data["loss"]["mean"] + valid_data["loss"]["std"],
# alpha=0.5,
# zorder=-1,
# color=colors[0],
# )
ax.plot(
train_data["epoch"],
train_data["interval"],
train_data["loss"]["mean"],
color=colors[3],
zorder=1,
label="Training",
)
# ax.fill_between(
# x=train_data["epoch"],
# x=train_data["interval"],
# y1=train_data["loss"]["mean"] - train_data["loss"]["std"],
# y2=train_data["loss"]["mean"] + train_data["loss"]["std"],
# alpha=0.5,
@@ -100,35 +95,35 @@ def plot(data: pd.DataFrame, min_epoch: int, output_path: str) -> None:

ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Epoch")
ax.set_xlabel("Interval")
ax.set_ylabel("Loss")
ax.legend()

ax = axes[1]
ax.plot(
valid_data["epoch"],
valid_data["interval"],
valid_data["mae_e"]["mean"],
color=colors[1],
zorder=1,
label="MAE Energy [eV]",
)
# ax.fill_between(
# x=valid_data["epoch"],
# x=valid_data["interval"],
# y1=valid_data["mae_e"]["mean"] - valid_data["mae_e"]["std"],
# y2=valid_data["mae_e"]["mean"] + valid_data["mae_e"]["std"],
# alpha=0.5,
# zorder=-1,
# color=colors[1],
# )
ax.plot(
valid_data["epoch"],
valid_data["interval"],
valid_data["mae_f"]["mean"],
color=colors[2],
zorder=1,
label="MAE Forces [eV/Å]",
)
# ax.fill_between(
# x=valid_data["epoch"],
# x=valid_data["interval"],
# y1=valid_data["mae_f"]["mean"] - valid_data["mae_f"]["std"],
# y2=valid_data["mae_f"]["mean"] + valid_data["mae_f"]["std"],
# alpha=0.5,
@@ -166,7 +161,7 @@ def main():
)

for (path, name), group in data.groupby(["path", "name"]):
plot(group, min_epoch=args.min_epoch, output_path=f"{path}/{name}.pdf")
plot(group, output_path=f"{path}/{name}.pdf")


if __name__ == "__main__":
15 changes: 10 additions & 5 deletions mace_jax/run_train.py
@@ -2,6 +2,7 @@
import sys

import gin
import jax

import mace_jax
from mace_jax import tools
@@ -31,19 +32,22 @@ def main():
train_loader, valid_loader, test_loader, atomic_energies_dict, r_max = datasets()

model_fn, params, num_message_passing = model(
r_max, atomic_energies_dict, train_loader.graphs, initialize_seed=seed
r_max=r_max,
atomic_energies_dict=atomic_energies_dict,
train_graphs=train_loader.graphs,
initialize_seed=seed,
)

params = reload(params)

predictor = lambda w, g: tools.predict_energy_forces_stress(
lambda *x: model_fn(w, *x), g
predictor = jax.jit(
lambda w, g: tools.predict_energy_forces_stress(lambda *x: model_fn(w, *x), g)
)

if checks(predictor, params, train_loader):
return

gradient_transform, max_num_epochs = optimizer(train_loader.approx_length())
gradient_transform, steps_per_interval, max_num_intervals = optimizer()
optimizer_state = gradient_transform.init(params)

logging.info(f"Number of parameters: {tools.count_parameters(params)}")
Expand All @@ -59,7 +63,8 @@ def main():
valid_loader,
test_loader,
gradient_transform,
max_num_epochs,
max_num_intervals,
steps_per_interval,
logger,
directory,
tag,
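
Two notable changes in this file: the predictor closure is now wrapped in `jax.jit`, and `optimizer()` now returns interval-based counters instead of a number of epochs. A self-contained sketch of the jit-plus-grad closure pattern (toy model and illustrative names only, not the mace-jax API):

```python
import jax
import jax.numpy as jnp

def model_fn(w, x):
    # Stand-in for the MACE energy model: returns a scalar "energy".
    return jnp.sum(w * x ** 2)

# Jitting the closure compiles energy + forces once per input structure; the forces
# are the negative gradient of the energy with respect to the positions-like argument.
predictor = jax.jit(
    lambda w, x: (model_fn(w, x), -jax.grad(model_fn, argnums=1)(w, x))
)

w = jnp.array([1.0, 2.0])
x = jnp.array([0.5, -0.5])
energy, forces = predictor(w, x)
print(energy, forces)
```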
2 changes: 2 additions & 0 deletions mace_jax/tools/gin_datasets.py
@@ -19,6 +19,7 @@ def datasets(
valid_fraction: float = None,
valid_num: int = None,
test_path: str = None,
test_num: int = None,
seed: int = 1234,
energy_key: str = "energy",
forces_key: str = "forces",
@@ -93,6 +94,7 @@
energy_key=energy_key,
forces_key=forces_key,
extract_atomic_energies=False,
num_configs=test_num,
prefactor_stress=prefactor_stress,
remap_stress=remap_stress,
)
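
The new `test_num` argument mirrors `valid_num` and caps how many test configurations are loaded. A hypothetical gin binding for it — the key name comes from the diff, while the stand-in function and the value are assumptions made for illustration:

```python
import gin

@gin.configurable
def datasets(test_num: int = None):
    # Toy stand-in for mace_jax.tools.gin_datasets.datasets, just to show the binding.
    return test_num

# Same syntax you would put in one of the .gin config files: cap the test set size.
gin.parse_config("datasets.test_num = 100")
print(datasets())  # -> 100
```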