Avoiding crashes for PES training without stresses and update pretrained models (#168)

* Optimize the Atoms2Graph and fix the np.meshgrid usage

* Add unit tests

* Improve _three_body.py and test_M3GNetCalculator in test_ase.py

* Add cpu() in ase.py and compute.py to enable GPU usage for the MatGL-LAMMPS interface

* Include a unit test for the Hessian in test_ase.py to improve the coverage score

* Remove redundant torch.unique for finding the maximum three_body index and clean up the unit tests

* Add a unit test for trainer.test and a description in the example

* Add an option for PES training without stresses

* Merge the changes and fix errors

* Add backward compatibility for data_mean in pes.py

---------

Co-authored-by: Shyue Ping Ong <[email protected]>
kenko911 and shyuep authored Sep 20, 2023
1 parent 13d71ca commit 375dcf3
Showing 9 changed files with 112 additions and 26 deletions.
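For readers who want to use the new option directly, the sketch below mirrors the test added in tests/utils/test_training.py in this commit: stresses are dropped from the collated batches via `include_stress=False` and excluded from the loss via `stress_weight=0.0`. The import paths and the `structures`/`energies`/`forces` placeholders are assumptions based on the matgl package layout at this commit, not part of the diff itself.

```python
from functools import partial

import pytorch_lightning as pl
from dgl.data.utils import split_dataset

from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.data import M3GNetDataset, MGLDataLoader, collate_fn_efs
from matgl.models import M3GNet
from matgl.utils.training import PotentialLightningModule

# `structures`, `energies` and `forces` are assumed to be user-supplied lists;
# note that no stress labels are required.
element_types = get_element_list(structures)
converter = Structure2Graph(element_types=element_types, cutoff=5.0)
dataset = M3GNetDataset(
    structures=structures,
    converter=converter,
    threebody_cutoff=4.0,
    labels={"energies": energies, "forces": forces},
)
train_data, val_data, test_data = split_dataset(
    dataset, frac_list=[0.8, 0.1, 0.1], shuffle=True, random_state=42
)

# Skip stresses in the collated batch and give them zero weight in the loss.
my_collate_fn = partial(collate_fn_efs, include_stress=False)
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    collate_fn=my_collate_fn,
    batch_size=2,
    num_workers=0,
)

model = M3GNet(element_types=element_types, is_intensive=False)
lit_model = PotentialLightningModule(model=model, stress_weight=0.0)
# inference_mode=False keeps autograd enabled so forces can be computed during validation.
trainer = pl.Trainer(max_epochs=5, inference_mode=False)
trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
```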
13 changes: 7 additions & 6 deletions matgl/apps/pes.py
@@ -7,6 +7,7 @@
from torch import nn
from torch.autograd import grad

import matgl
from matgl.layers import AtomRef
from matgl.utils.io import IOMixIn

@@ -52,14 +53,14 @@ def __init__(
self.calc_site_wise = calc_site_wise
self.element_refs: AtomRef | None
if element_refs is not None:
self.element_refs = AtomRef(property_offset=element_refs)
self.element_refs = AtomRef(property_offset=torch.tensor(element_refs, dtype=matgl.float_th))
else:
self.element_refs = None
data_mean = data_mean or 0
data_std = data_std or 1

self.data_mean = data_mean.clone().detach() if isinstance(data_mean, torch.Tensor) else torch.tensor(data_mean)
self.data_std = data_std.clone().detach() if isinstance(data_std, torch.Tensor) else torch.tensor(data_std)
# for backward compatibility
if data_mean is None:
data_mean = 0.0
self.register_buffer("data_mean", torch.tensor(data_mean, dtype=matgl.float_th))
self.register_buffer("data_std", torch.tensor(data_std, dtype=matgl.float_th))

def forward(
self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: dgl.DGLGraph | None = None
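The data_mean/data_std handling above switches from plain tensor attributes to `register_buffer`, so the scaling constants now move with the model on `.to(device)` and are saved in its `state_dict`. A minimal, self-contained illustration of the pattern (a toy module, not the actual Potential class):

```python
import torch
from torch import nn


class Scaler(nn.Module):
    """Toy module showing why buffers are preferred over plain tensor attributes."""

    def __init__(self, data_mean: float = 0.0, data_std: float = 1.0):
        super().__init__()
        # Buffers are moved by .to()/.cuda() and serialized with state_dict,
        # but are not trainable parameters.
        self.register_buffer("data_mean", torch.tensor(data_mean))
        self.register_buffer("data_std", torch.tensor(data_std))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.data_std + self.data_mean


scaler = Scaler(data_mean=0.0, data_std=0.3896)
print(scaler.state_dict())  # contains data_mean and data_std
```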
12 changes: 8 additions & 4 deletions matgl/graph/data.py
@@ -36,14 +36,18 @@ def collate_fn(batch, include_line_graph: bool = False):
return g, labels, state_attr


def collate_fn_efs(batch):
def collate_fn_efs(batch, include_stress: bool = True):
"""Merge a list of dgl graphs to form a batch."""
graphs, line_graphs, state_attr, labels = map(list, zip(*batch))
g = dgl.batch(graphs)
l_g = dgl.batch(line_graphs)
e = torch.tensor([d["energies"] for d in labels])
f = torch.vstack([d["forces"] for d in labels])
s = torch.vstack([d["stresses"] for d in labels])
e = torch.tensor([d["energies"] for d in labels]) # type: ignore
f = torch.vstack([d["forces"] for d in labels]) # type: ignore
s = (
torch.vstack([d["stresses"] for d in labels]) # type: ignore
if include_stress is True
else torch.tensor(np.zeros(e.size(dim=0)), dtype=matgl.float_th)
)
state_attr = torch.stack(state_attr)
return g, l_g, state_attr, e, f, s

9 changes: 5 additions & 4 deletions matgl/layers/_basis.py
@@ -70,7 +70,7 @@ def __init__(self, max_l: int, max_n: int = 5, cutoff: float = 5.0, smooth: bool
super().__init__()
self.max_l = max_l
self.max_n = max_n
self.cutoff = torch.tensor(cutoff)
self.register_buffer("cutoff", torch.tensor(cutoff))
self.smooth = smooth
if smooth:
self.funcs = self._calculate_smooth_symbolic_funcs()
@@ -116,7 +116,7 @@ def _call_sbf(self, r):
results = []
factor = torch.tensor(sqrt(2.0 / self.cutoff**3))
for i in range(self.max_l):
root = roots[i]
root = torch.tensor(roots[i])
func = self.funcs[i]
func_add1 = self.funcs[i + 1]
results.append(
@@ -218,7 +218,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return result / self.interval * self.scale_factor


class SphericalHarmonicsFunction:
class SphericalHarmonicsFunction(nn.Module):
"""Spherical Harmonics function."""

def __init__(self, max_l: int, use_phi: bool = True):
@@ -228,6 +228,7 @@ def __init__(self, max_l: int, use_phi: bool = True):
use_phi: bool, whether to use the polar angle. If not,
the function will compute `Y_l^0`.
"""
super().__init__()
self.max_l = max_l
self.use_phi = use_phi
funcs = []
@@ -244,7 +245,7 @@ def __init__(self, max_l: int, use_phi: bool = True):
self.funcs = [sympy.lambdify([costheta, phi], i, [{"conjugate": _conjugate}, torch]) for i in self.orig_funcs]
self.funcs[0] = _y00

def __call__(self, costheta, phi=None):
def forward(self, costheta, phi=None):
"""Args:
costheta: Cosine of the azimuthal angle
phi: torch.Tensor, the polar angle.
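The _basis.py changes follow the same theme: `cutoff` becomes a registered buffer and SphericalHarmonicsFunction is turned into an nn.Module whose `__call__` is replaced by `forward`. A schematic of that conversion pattern using a toy basis function (not the actual spherical-harmonics implementation, whose sympy-generated terms are omitted):

```python
import math

import torch
from torch import nn


class CosineBasis(nn.Module):
    """Schematic of converting a plain callable class into an nn.Module:
    subclass nn.Module, call super().__init__(), and rename __call__ to forward."""

    def __init__(self, cutoff: float = 5.0):
        super().__init__()
        # Registering the cutoff as a buffer keeps it on the same device as the model.
        self.register_buffer("cutoff", torch.tensor(cutoff))

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        # Smooth cutoff envelope: 1 at r=0, 0 at r=cutoff.
        return 0.5 * (torch.cos(math.pi * r / self.cutoff) + 1.0)


basis = CosineBasis()
print(basis(torch.linspace(0.0, 5.0, 6)))  # nn.Module.__call__ dispatches to forward
```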
2 changes: 1 addition & 1 deletion pretrained_models/M3GNet-MP-2021.2.8-PES/model.json
@@ -127,7 +127,7 @@
"data_std": "tensor(0.3896)"
}
},
"data_mean": null,
"data_mean": 0.0,
"data_std": 0.389599794401098,
"element_refs": [
-3.47784146,
Binary file modified pretrained_models/M3GNet-MP-2021.2.8-PES/model.pt
8 changes: 4 additions & 4 deletions tests/ext/test_ase.py
@@ -14,7 +14,7 @@
def test_M3GNetCalculator(MoS):
adaptor = AseAtomsAdaptor()
s_ase = adaptor.get_atoms(MoS) # type: ignore
ff = load_model("M3GNet-MP-2021.2.8-PES")
ff = load_model("pretrained_models/M3GNet-MP-2021.2.8-PES/")
ff.calc_hessian = True
calc = M3GNetCalculator(potential=ff)
s_ase.set_calculator(calc)
@@ -28,7 +28,7 @@ def test_M3GNetCalculator_mol(AcAla3NHMe):
def test_M3GNetCalculator_mol(AcAla3NHMe):
adaptor = AseAtomsAdaptor()
mol = adaptor.get_atoms(AcAla3NHMe)
ff = load_model("M3GNet-MP-2021.2.8-PES")
ff = load_model("pretrained_models/M3GNet-MP-2021.2.8-PES/")
calc = M3GNetCalculator(potential=ff)
mol.set_calculator(calc)
assert [mol.get_potential_energy().size] == [1]
@@ -37,7 +37,7 @@ def test_M3GNetCalculator_mol(AcAla3NHMe):


def test_Relaxer(MoS):
pot = load_model("M3GNet-MP-2021.2.8-PES")
pot = load_model("pretrained_models/M3GNet-MP-2021.2.8-PES/")
r = Relaxer(pot)
results = r.relax(MoS, traj_file="MoS_relax.traj")
s = results["final_structure"]
@@ -84,7 +84,7 @@ def test_get_graph_from_atoms_mol():


def test_molecular_dynamics(MoS):
pot = load_model("M3GNet-MP-2021.2.8-PES")
pot = load_model("pretrained_models/M3GNet-MP-2021.2.8-PES/")
for ensemble in ["nvt", "nvt_langevin", "nvt_andersen", "npt", "npt_berendsen", "npt_nose_hoover"]:
md = MolecularDynamics(MoS, potential=pot, ensemble=ensemble, taut=0.1, taup=0.1, compressibility_au=10)
md.run(10)
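The ASE tests above now load the pretrained PES model from the local pretrained_models/ directory instead of by name (the by-name path is still exercised in tests/utils/test_io.py with M3GNet-MP-2021.2.8-DIRECT-PES). A hedged sketch of the equivalent user-facing workflow; the bcc Mo cell is just a stand-in structure, and the load_model import location is an assumption based on the package layout at this commit:

```python
from ase.build import bulk

from matgl.ext.ase import M3GNetCalculator
from matgl.utils.io import load_model

# load_model accepts either a path to a saved model directory or a
# registered pretrained-model name.
pot = load_model("pretrained_models/M3GNet-MP-2021.2.8-PES/")
pot.calc_hessian = True  # optionally also compute the Hessian, as in test_M3GNetCalculator

atoms = bulk("Mo", "bcc", a=3.17)  # stand-in for any ASE Atoms object
atoms.calc = M3GNetCalculator(potential=pot)
print(atoms.get_potential_energy())
```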
44 changes: 38 additions & 6 deletions tests/graph/test_data.py
@@ -11,12 +11,7 @@
from pymatgen.core import Molecule

from matgl.ext.pymatgen import Molecule2Graph, Structure2Graph, get_element_list
from matgl.graph.data import (
M3GNetDataset,
MEGNetDataset,
MGLDataLoader,
collate_fn,
)
from matgl.graph.data import M3GNetDataset, MEGNetDataset, MGLDataLoader, collate_fn, collate_fn_efs

module_dir = os.path.dirname(os.path.abspath(__file__))

@@ -243,6 +238,43 @@ def test_m3gnet_dataloader(self, LiFePO4, BaNiO3):
os.remove("dgl_line_graph.bin")
os.remove("state_attr.pt")

def test_m3gnet_dataloader_without_stresses(self, LiFePO4, BaNiO3):
structures = [LiFePO4, BaNiO3] * 10
energies = np.zeros(20).tolist()
f1 = np.zeros((28, 3)).tolist()
f2 = np.zeros((10, 3)).tolist()
np.zeros((3, 3)).tolist()
forces = [f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2]
element_types = get_element_list([LiFePO4, BaNiO3])
cry_graph = Structure2Graph(element_types=element_types, cutoff=4.0)
dataset = M3GNetDataset(
structures=structures,
converter=cry_graph,
threebody_cutoff=4.0,
labels={"energies": energies, "forces": forces},
)
train_data, val_data, test_data = split_dataset(
dataset,
frac_list=[0.8, 0.1, 0.1],
shuffle=True,
random_state=42,
)
my_collate_fn = partial(collate_fn_efs, include_stress=False)
train_loader, val_loader, test_loader = MGLDataLoader(
train_data=train_data,
val_data=val_data,
test_data=test_data,
collate_fn=my_collate_fn,
batch_size=2,
num_workers=1,
)
assert len(train_loader) == 8
assert len(val_loader) == 1
assert len(test_loader) == 1
os.remove("dgl_graph.bin")
os.remove("dgl_line_graph.bin")
os.remove("state_attr.pt")

def test_m3gnet_property_dataloader(self, LiFePO4, BaNiO3):
structures = [LiFePO4, BaNiO3] * 10
e_form = np.zeros(20)
2 changes: 1 addition & 1 deletion tests/utils/test_io.py
@@ -71,7 +71,7 @@ def test_get_available_pretrained_models():

def test_load_model():
# Load model from name.
model = load_model("M3GNet-MP-2021.2.8-PES")
model = load_model("M3GNet-MP-2021.2.8-DIRECT-PES")
assert issubclass(model.__class__, torch.nn.Module)

# Load model from a full path.
48 changes: 48 additions & 0 deletions tests/utils/test_training.py
@@ -128,6 +128,54 @@ def test_m3gnet_training(self, LiFePO4, BaNiO3):
os.remove("dgl_line_graph.bin")
os.remove("state_attr.pt")

def test_m3gnet_training_without_stress(self, LiFePO4, BaNiO3):
isolated_atom = Structure(Lattice.cubic(10.0), ["Li"], [[0, 0, 0]])
two_body = Structure(Lattice.cubic(10.0), ["Li", "Li"], [[0, 0, 0], [0.2, 0, 0]])
structures = [LiFePO4, BaNiO3] * 5 + [isolated_atom, two_body]
energies = [-2.0, -3.0] * 5 + [-1.0, -1.5]
forces = [np.zeros((len(s), 3)).tolist() for s in structures]
element_types = get_element_list([LiFePO4, BaNiO3])
converter = Structure2Graph(element_types=element_types, cutoff=5.0)
dataset = M3GNetDataset(
threebody_cutoff=4.0,
structures=structures,
converter=converter,
labels={"energies": energies, "forces": forces},
)
train_data, val_data, test_data = split_dataset(
dataset,
frac_list=[0.8, 0.1, 0.1],
shuffle=True,
random_state=42,
)
my_collate_fn = partial(collate_fn_efs, include_stress=False)
train_loader, val_loader, test_loader = MGLDataLoader(
train_data=train_data,
val_data=val_data,
test_data=test_data,
collate_fn=my_collate_fn,
batch_size=2,
num_workers=0,
generator=torch.Generator(device=device),
)
model = M3GNet(element_types=element_types, is_intensive=False)
lit_model = PotentialLightningModule(model=model, stress_weight=0.0)
# We will use CPU if MPS is available since there is a serious bug.
trainer = pl.Trainer(max_epochs=5, accelerator=device, inference_mode=False)

trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(lit_model, dataloaders=test_loader)

pred_LFP_energy = model.predict_structure(LiFePO4)
pred_BNO_energy = model.predict_structure(BaNiO3)

# We are not expecting accuracy with 2 epochs. This just tests that the energy is actually < 0.
assert pred_LFP_energy < 0
assert pred_BNO_energy < 0
os.remove("dgl_graph.bin")
os.remove("dgl_line_graph.bin")
os.remove("state_attr.pt")

def test_m3gnet_property_training(self, LiFePO4, BaNiO3):
isolated_atom = Structure(Lattice.cubic(10.0), ["Li"], [[0, 0, 0]])
structures = [LiFePO4] * 5 + [BaNiO3] * 5 + [isolated_atom]
