Skip to content

Commit

Permalink
Modified collate_fn for predicting multiple values per target propert…
Browse files Browse the repository at this point in the history
…y with M3GNet model and included a unit test
  • Loading branch information
kenko911 committed Nov 17, 2023
1 parent 4587b9b commit 3451e0d
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
8 changes: 6 additions & 2 deletions matgl/graph/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,18 @@
from matgl.graph.converters import GraphConverter


def collate_fn(batch, include_line_graph: bool = False):
def collate_fn(batch, include_line_graph: bool = False, multiple_values_per_target: bool = False):
"""Merge a list of dgl graphs to form a batch."""
if include_line_graph:
graphs, lattices, line_graphs, state_attr, labels = map(list, zip(*batch))
else:
graphs, lattices, state_attr, labels = map(list, zip(*batch))
g = dgl.batch(graphs)
labels = torch.tensor([next(iter(d.values())) for d in labels], dtype=matgl.float_th) # type: ignore
labels = (
torch.vstack([next(iter(d.values())) for d in labels])
if multiple_values_per_target
else torch.tensor([next(iter(d.values())) for d in labels], dtype=matgl.float_th)
)
state_attr = torch.stack(state_attr)
lat = lattices[0] if g.batch_size == 1 else torch.squeeze(torch.stack(lattices))
if include_line_graph:
Expand Down
67 changes: 67 additions & 0 deletions tests/utils/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,73 @@ def test_m3gnet_property_training(self, LiFePO4, BaNiO3):
assert "MAE" in results[0][0]
self.teardown_class()

def test_m3gnet_property_trainin_multiple_values_per_target(self, LiFePO4, BaNiO3):
isolated_atom = Structure(Lattice.cubic(10.0), ["Li"], [[0, 0, 0]])
structures = [LiFePO4] * 5 + [BaNiO3] * 5 + [isolated_atom]
label = np.full((11, 5), -1.0).tolist() # Artificial dataset.
element_types = get_element_list([LiFePO4, BaNiO3])
converter = Structure2Graph(element_types=element_types, cutoff=5.0)
dataset = M3GNetDataset(
threebody_cutoff=4.0, structures=structures, converter=converter, labels={"multiple_values": label}
)
train_data, val_data, test_data = split_dataset(
dataset,
frac_list=[0.8, 0.1, 0.1],
shuffle=True,
random_state=42,
)
# This modification is required for M3GNet property dataset
collate_fn_property = partial(collate_fn, include_line_graph=True, multiple_values_per_target=True)
train_loader, val_loader, test_loader = MGLDataLoader(
train_data=train_data,
val_data=val_data,
test_data=test_data,
collate_fn=collate_fn_property,
batch_size=2,
num_workers=0,
generator=torch.Generator(device=device),
)
model = M3GNet(
element_types=element_types,
is_intensive=True,
readout_type="set2set",
ntargets=5,
)
lit_model = ModelLightningModule(model=model)
# We will use CPU if MPS is available since there is a serious bug.
trainer = pl.Trainer(max_epochs=2, accelerator=device)

trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

pred_LFP_energy = model.predict_structure(LiFePO4)
pred_BNO_energy = model.predict_structure(BaNiO3)

# We are not expecting accuracy with 2 epochs. This just tests that the energy is actually < 0.
assert torch.any(pred_LFP_energy < 0)
assert torch.any(pred_BNO_energy < 0)

results = trainer.predict(model=lit_model, dataloaders=test_loader)

assert "MAE" in results[0][0]

lit_model = ModelLightningModule(model=model, loss="l1_loss")
# We will use CPU if MPS is available since there is a serious bug.
trainer = pl.Trainer(max_epochs=2, accelerator=device)

trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

pred_LFP_energy = model.predict_structure(LiFePO4)
pred_BNO_energy = model.predict_structure(BaNiO3)

# We are not expecting accuracy with 2 epochs. This just tests that the energy is actually < 0.
assert torch.any(pred_LFP_energy < 0)
assert torch.any(pred_BNO_energy < 0)

results = trainer.predict(model=lit_model, dataloaders=test_loader)

assert "MAE" in results[0][0]
self.teardown_class()

@classmethod
def teardown_class(cls):
for fn in ("dgl_graph.bin", "lattice.pt", "dgl_line_graph.bin", "state_attr.pt", "labels.json"):
Expand Down

0 comments on commit 3451e0d

Please sign in to comment.