Skip to content

Commit

Permalink
included symbreak for chgnet training to improve code coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
kenko911 committed Oct 22, 2024
2 parents a997119 + 9632254 commit a2c6b80
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 36 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ pymatgen==2024.10.3
ase==3.23.0
pydantic==2.9.2
torchdata==0.7.1
boto3==1.35.39
boto3==1.35.44
numpy==1.26.4
sympy==1.13.3
1 change: 0 additions & 1 deletion src/matgl/ext/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,6 @@ def __init__(
stress_weight=stress_weight, # type: ignore
)
self.relax_cell = relax_cell
self.potential = potential
self.ase_adaptor = AseAtomsAdaptor()

def relax(
Expand Down
75 changes: 42 additions & 33 deletions src/matgl/utils/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def loss_fn(
loss: nn.Module,
labels: tuple,
preds: tuple,
num_atoms: int | None = None,
num_atoms: torch.Tensor | None = None,
):
"""Compute losses for EFS.
Expand All @@ -462,14 +462,31 @@ def loss_fn(
"""
# labels and preds are (energy, force, stress, (optional) site_wise)
e_loss = self.loss(labels[0] / num_atoms, preds[0] / num_atoms, **self.loss_params)
f_loss = self.loss(labels[1], preds[1], **self.loss_params)
if num_atoms is None:
num_atoms = torch.ones_like(preds[0])
if self.allow_missing_labels:
valid_labels, valid_preds = [], []
for index, label in enumerate(labels):
valid_value_indices = ~torch.isnan(label)
valid_labels.append(label[valid_value_indices])
if index == 0:
valid_num_atoms = num_atoms[valid_value_indices]
pred = preds[index].view(1) if preds[index].shape == torch.Size([]) else preds[index]
else:
pred = preds[index]
valid_preds.append(pred[valid_value_indices])
else:
valid_labels, valid_preds = list(labels), list(preds)
valid_num_atoms = num_atoms

e_loss = self.loss(valid_labels[0] / valid_num_atoms, valid_preds[0] / valid_num_atoms, **self.loss_params)
f_loss = self.loss(valid_labels[1], valid_preds[1], **self.loss_params)

e_mae = self.mae(labels[0] / num_atoms, preds[0] / num_atoms)
f_mae = self.mae(labels[1], preds[1])
e_mae = self.mae(valid_labels[0] / valid_num_atoms, valid_preds[0] / valid_num_atoms)
f_mae = self.mae(valid_labels[1], valid_preds[1])

e_rmse = self.rmse(labels[0] / num_atoms, preds[0] / num_atoms)
f_rmse = self.rmse(labels[1], preds[1])
e_rmse = self.rmse(valid_labels[0] / valid_num_atoms, valid_preds[0] / valid_num_atoms)
f_rmse = self.rmse(valid_labels[1], valid_preds[1])

s_mae = torch.zeros(1)
s_rmse = torch.zeros(1)
Expand All @@ -480,36 +497,28 @@ def loss_fn(
total_loss = self.energy_weight * e_loss + self.force_weight * f_loss

if self.model.calc_stresses:
s_loss = loss(labels[2], preds[2], **self.loss_params)
s_mae = self.mae(labels[2], preds[2])
s_rmse = self.rmse(labels[2], preds[2])
s_loss = loss(valid_labels[2], valid_preds[2], **self.loss_params)
s_mae = self.mae(valid_labels[2], valid_preds[2])
s_rmse = self.rmse(valid_labels[2], valid_preds[2])
total_loss = total_loss + self.stress_weight * s_loss

if self.model.calc_magmom:
if self.allow_missing_labels:
valid_values = ~torch.isnan(labels[3])
labels_3 = labels[3][valid_values]
preds_3 = preds[3][valid_values]
if self.model.calc_magmom and labels[3].numel() > 0:
if self.magmom_target == "symbreak":
m_loss = torch.min(
loss(valid_labels[3], valid_preds[3], **self.loss_params),
loss(valid_labels[3], -valid_preds[3], **self.loss_params),
)
m_mae = torch.min(self.mae(valid_labels[3], valid_preds[3]), self.mae(valid_labels[3], -valid_preds[3]))
m_rmse = torch.min(
self.rmse(valid_labels[3], valid_preds[3]), self.rmse(valid_labels[3], -valid_preds[3])
)
else:
labels_3 = labels[3]
preds_3 = preds[3]

if len(labels_3) > 0:
if self.magmom_target == "symbreak":
m_loss = torch.min(
loss(labels_3, preds_3, **self.loss_params), loss(labels_3, -preds_3, **self.loss_params)
)
m_mae = torch.min(self.mae(labels_3, preds_3), self.mae(labels_3, -preds_3))
m_rmse = torch.min(self.rmse(labels_3, preds_3), self.rmse(labels_3, -preds_3))
else:
if self.magmom_target == "absolute":
labels_3 = torch.abs(labels_3)

m_loss = loss(labels_3, preds_3, **self.loss_params)
m_mae = self.mae(labels_3, preds_3)
m_rmse = self.rmse(labels_3, preds_3)
labels_3 = torch.abs(valid_labels[3]) if self.magmom_target == "absolute" else valid_labels[3]
m_loss = loss(labels_3, valid_preds[3], **self.loss_params)
m_mae = self.mae(labels_3, valid_preds[3])
m_rmse = self.rmse(labels_3, valid_preds[3])

total_loss = total_loss + self.magmom_weight * m_loss
total_loss = total_loss + self.magmom_weight * m_loss

return {
"Total_Loss": total_loss,
Expand Down
68 changes: 67 additions & 1 deletion tests/utils/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,9 @@ def test_chgnet_training(self, LiFePO4, BaNiO3):
generator=torch.Generator(device=device),
)
model = CHGNet(element_types=element_types, is_intensive=False)
lit_model = PotentialLightningModule(model=model, stress_weight=0.1, magmom_weight=0.1, include_line_graph=True)
lit_model = PotentialLightningModule(
model=model, stress_weight=0.1, magmom_weight=0.1, include_line_graph=True, magmom_target="symbreak"
)
# We will use CPU if MPS is available since there is a serious bug.
trainer = pl.Trainer(max_epochs=2, accelerator=device, inference_mode=False)

Expand Down Expand Up @@ -773,6 +775,70 @@ def test_chgnet_training_without_m(self, LiFePO4, BaNiO3):

self.teardown_class()

def test_chgnet_training_with_missing_label(self, LiFePO4, BaNiO3):
    """End-to-end CHGNet training with NaN-masked (missing) labels.

    Builds a small dataset of 10 structures where one energy, one force
    array, one stress tensor, and one magmom array are deliberately set
    to NaN, then trains a PotentialLightningModule constructed with
    allow_missing_labels=True so the loss function skips the NaN entries.
    Finally checks that predicted energies for both structures are
    negative (i.e. the model learned something physically sensible).

    Args:
        LiFePO4: pytest fixture providing a pymatgen LiFePO4 structure.
        BaNiO3: pytest fixture providing a pymatgen BaNiO3 structure.
    """
    # 10 structures alternating between the two fixtures, with toy labels.
    structures = [LiFePO4, BaNiO3] * 5
    energies = [-2.0, -3.0] * 5
    forces = [np.ones((len(s), 3)).tolist() for s in structures]
    stresses = [np.zeros((3, 3)).tolist()] * len(structures)
    magmoms = [np.ones((len(s), 1)).tolist() for s in structures]
    # Create some missing labels: one NaN entry per label category,
    # each on a different structure, to exercise the masking path.
    energies[2] = np.nan
    forces[4] = (np.nan * np.ones((len(structures[4]), 3))).tolist()
    stresses[6] = (np.nan * np.ones((3, 3))).tolist()
    magmoms[8] = (np.nan * np.ones((len(structures[8]), 1))).tolist()
    element_types = get_element_list([LiFePO4, BaNiO3])
    converter = Structure2Graph(element_types=element_types, cutoff=6.0)
    # Line graphs are required because CHGNet consumes angular information.
    dataset = MGLDataset(
        threebody_cutoff=3.0,
        structures=structures,
        converter=converter,
        include_line_graph=True,
        directed_line_graph=True,
        labels={
            "energies": energies,
            "forces": forces,
            "stresses": stresses,
            "magmoms": magmoms,
        },
        save_cache=False,
    )
    # Fixed random_state keeps the 8/1/1 split deterministic across runs.
    train_data, val_data, test_data = split_dataset(
        dataset,
        frac_list=[0.8, 0.1, 0.1],
        shuffle=True,
        random_state=42,
    )
    train_loader, val_loader, test_loader = MGLDataLoader(
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        collate_fn=partial(collate_fn_pes, include_magmom=True, include_line_graph=True),
        batch_size=4,
        num_workers=0,
        generator=torch.Generator(device=device),
    )
    model = CHGNet(element_types=element_types, is_intensive=False)
    # allow_missing_labels=True is the feature under test: NaN labels must
    # be filtered out of the loss instead of poisoning the gradients.
    lit_model = PotentialLightningModule(
        model=model,
        stress_weight=0.1,
        magmom_weight=0.1,
        include_line_graph=True,
        allow_missing_labels=True,
    )
    # We will use CPU if MPS is available since there is a serious bug.
    # inference_mode=False because force/stress prediction needs autograd
    # during validation/testing as well.
    trainer = pl.Trainer(max_epochs=2, accelerator=device, inference_mode=False)

    trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    trainer.test(lit_model, dataloaders=test_loader)

    pred_LFP_energy = model.predict_structure(LiFePO4)
    pred_BNO_energy = model.predict_structure(BaNiO3)

    # 2 epochs on toy data is a smoke test: only require negative (bound)
    # energies rather than any quantitative accuracy.
    assert torch.any(pred_LFP_energy < 0)
    assert torch.any(pred_BNO_energy < 0)

    self.teardown_class()

@classmethod
def teardown_class(cls):
try:
Expand Down

0 comments on commit a2c6b80

Please sign in to comment.