Merge pull request #237 from laserkelvin/helper-callback
Quality of life and helper callback functions
laserkelvin authored Jul 1, 2024
2 parents 430712f + d53cde4 commit 0e3a640
Showing 6 changed files with 864 additions and 19 deletions.
65 changes: 65 additions & 0 deletions examples/callbacks/autocorrelation.py
@@ -0,0 +1,65 @@
from __future__ import annotations

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from matsciml.datasets.transforms import DistancesTransform, PointCloudToGraphTransform
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.lightning.callbacks import ModelAutocorrelation
from matsciml.models import SchNet
from matsciml.models.base import ScalarRegressionTask

"""
This script demonstrates the use of the `ModelAutocorrelation` callback.
The main utility of this callback is to monitor the degree of correlation
in model parameters and optionally gradients over a time span. The idea
is that for optimization trajectories, steps are ideally as de-correlated
as possible (at least within reason), and indeed is actually a major
assumption of Adam-like optimizers.
There is no hard coded heuristic for identifying "too much correlation"
yet, however this callback can help do the data collection for you to
develop a sense for yourself. One method for trying this out is to
set varying learning rates, and seeing how the autocorrelation spectra
are different.
"""

# construct a scalar regression task with SchNet encoder
task = ScalarRegressionTask(
encoder_class=SchNet,
# kwargs to be passed into the creation of SchNet model
encoder_kwargs={
"encoder_only": True,
"hidden_feats": [128, 128, 128],
"atom_embedding_dim": 128,
},
# which keys to use as targets
task_keys=["energy_relaxed"],
log_embeddings=False,
)
# Use IS2RE devset to test workflow
# SchNet uses RBFs, and expects edge features corresponding to atom-atom distances
dm = MatSciMLDataModule.from_devset(
"IS2REDataset",
dset_kwargs={
"transforms": [
PointCloudToGraphTransform(
"dgl",
cutoff_dist=20.0,
node_keys=["pos", "atomic_numbers"],
),
DistancesTransform(),
],
},
)

# tensorboard logging if working purely locally; swap in wandb otherwise
# logger = WandbLogger(
#     name="helper-callback", offline=False, project="matsciml", log_model="all"
# )
logger = TensorBoardLogger("./")

# run a quick training loop
trainer = pl.Trainer(max_epochs=30, logger=logger, callbacks=[ModelAutocorrelation()])
trainer.fit(task, datamodule=dm)
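To build intuition for what the callback measures, the lag-k autocorrelation of a parameter history can be sketched offline. The following is a minimal illustration, assuming a flattened (steps, num_params) snapshot matrix; the ``lagged_autocorrelation`` helper is hypothetical and is not the callback's actual implementation.

import torch


def lagged_autocorrelation(history: torch.Tensor, max_lag: int = 10) -> torch.Tensor:
    """Illustrative sketch: lag-k autocorrelation of a (steps, num_params) history.

    Each parameter trace is mean-centered, correlated with a lag-shifted
    copy of itself, normalized by its variance, then averaged over parameters.
    """
    centered = history - history.mean(dim=0, keepdim=True)
    denom = (centered * centered).sum(dim=0) + 1e-12
    spectra = []
    for lag in range(1, max_lag + 1):
        num = (centered[:-lag] * centered[lag:]).sum(dim=0)
        spectra.append((num / denom).mean())
    return torch.stack(spectra)


# e.g. 100 optimizer steps of a 128-dim flattened parameter vector;
# a cumulative sum gives a highly correlated random walk for comparison
history = torch.cumsum(torch.randn(100, 128), dim=0)
print(lagged_autocorrelation(history, max_lag=5))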
57 changes: 57 additions & 0 deletions examples/callbacks/helper.py
@@ -0,0 +1,57 @@
from __future__ import annotations

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from matsciml.datasets.transforms import DistancesTransform, PointCloudToGraphTransform
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.lightning.callbacks import TrainingHelperCallback
from matsciml.models import SchNet
from matsciml.models.base import ScalarRegressionTask

"""
This script demonstrates the use of the ``TrainingHelperCallback``
callback. The purpose of this callback is to provide some
helpful heuristics into the training process by identifying
some common issues like unused weights, small gradients,
and oversmoothed embeddings.
"""

# construct a scalar regression task with SchNet encoder
task = ScalarRegressionTask(
encoder_class=SchNet,
# kwargs to be passed into the creation of SchNet model
encoder_kwargs={
"encoder_only": True,
"hidden_feats": [128, 128, 128],
"atom_embedding_dim": 128,
},
# which keys to use as targets
task_keys=["energy_relaxed"],
log_embeddings=True,
)
# Use IS2RE devset to test workflow
# SchNet uses RBFs, and expects edge features corresponding to atom-atom distances
dm = MatSciMLDataModule.from_devset(
"IS2REDataset",
dset_kwargs={
"transforms": [
PointCloudToGraphTransform(
"dgl",
cutoff_dist=20.0,
node_keys=["pos", "atomic_numbers"],
),
DistancesTransform(),
],
},
)

# tensorboard logging if working purely locally
# logger = TensorBoardLogger("./")
logger = WandbLogger(
name="helper-callback", offline=False, project="matsciml", log_model="all"
)

# run a quick training loop
trainer = pl.Trainer(max_epochs=10, logger=logger, callbacks=[TrainingHelperCallback()])
trainer.fit(task, datamodule=dm)
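The sort of heuristic the callback applies can be illustrated with a stripped-down Lightning callback. This is a hedged sketch: the hook and the named_parameters scan are standard PyTorch Lightning/PyTorch APIs, but the class, threshold, and messages below are made up and are not the actual ``TrainingHelperCallback`` logic.

import pytorch_lightning as pl


class GradientSanityCallback(pl.Callback):
    """Illustrative sketch: flag unused weights and vanishing gradients.

    Hypothetical example; the threshold is arbitrary, not the value
    used by ``TrainingHelperCallback``.
    """

    def __init__(self, small_grad_threshold: float = 1e-7):
        super().__init__()
        self.small_grad_threshold = small_grad_threshold

    def on_after_backward(self, trainer, pl_module):
        for name, param in pl_module.named_parameters():
            if not param.requires_grad:
                continue
            if param.grad is None:
                # parameter never received a gradient: likely an unused weight
                print(f"[helper] no gradient for {name}; possibly unused")
            elif param.grad.abs().max() < self.small_grad_threshold:
                print(f"[helper] very small gradients for {name}")


# usage: pl.Trainer(callbacks=[GradientSanityCallback()])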
15 changes: 11 additions & 4 deletions matsciml/datasets/utils.py
@@ -9,6 +9,7 @@
 
 import lmdb
 import torch
+import numpy as np
 from einops import einsum, rearrange
 from joblib import Parallel, delayed
 from pymatgen.core import Lattice, Structure
@@ -302,11 +303,11 @@ def get_lmdb_keys(
             "Both `ignore_keys` and `_lambda` were passed; arguments are mutually exclusive.",
         )
     if ignore_keys:
-        _lambda = lambda x: x not in ignore_keys
+        _lambda = lambda x: x not in ignore_keys  # noqa: E731
     else:
         if not _lambda:
             # escape case where we basically don't filter
-            _lambda = lambda x: x
+            _lambda = lambda x: x  # noqa: E731
     # convert to a sorted list of keys
     keys = sorted(list(filter(_lambda, keys)))
     return keys
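For reference, the filter-with-lambda idiom in this hunk is easy to exercise in isolation; a minimal sketch with made-up keys and no LMDB involved:

# hypothetical keys as they might come back from an LMDB environment
keys = ["metadata", "0", "1", "2"]
ignore_keys = ["metadata"]
_lambda = lambda x: x not in ignore_keys  # noqa: E731
print(sorted(filter(_lambda, keys)))  # ['0', '1', '2']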
@@ -529,7 +530,7 @@ def divide_data_chunks(
     assert all(
         [length != 0 for length in lengths],
     ), "Too many processes specified and not enough data to split over multiple LMDB files. Decrease `num_procs`!"
-    p = Parallel(num_procs)(
+    _ = Parallel(num_procs)(
         delayed(write_chunk)(chunk, target_dir, index, metadata)
         for chunk, index in zip(chunks, lmdb_indices)
     )
@@ -693,6 +694,11 @@ def calculate_periodic_shifts(
         include_index=True,
         include_image=True,
     )
+    # check to make sure the cell definition is valid
+    if np.any(structure.frac_coords > 1.0):
+        raise ValueError(
+            f"Structure has fractional coordinates greater than 1! Check structure:\n{structure}"
+        )
 
     def _all_sites_have_neighbors(neighbors):
         return all([len(n) for n in neighbors])
@@ -729,12 +735,13 @@ def _all_sites_have_neighbors(neighbors):
     cell = torch.from_numpy(cell.copy()).float()
     # get coordinates as well, for standardization
     frac_coords = torch.from_numpy(structure.frac_coords).float()
+    coords = torch.from_numpy(structure.cart_coords).float()
     return_dict = {
         "src_nodes": torch.LongTensor(all_src),
         "dst_nodes": torch.LongTensor(all_dst),
         "images": torch.FloatTensor(all_images),
         "cell": cell,
-        "pos": frac_coords,
+        "pos": coords,
     }
     # now calculate offsets based on each image for a lattice
     return_dict["offsets"] = einsum(return_dict["images"], cell, "v i, n i j -> v j")
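The last line of this hunk contracts each integer periodic-image vector with the cell matrix to produce Cartesian offset vectors, which pairs with "pos" now holding Cartesian rather than fractional coordinates. A self-contained sketch of the same einsum, assuming a made-up 5 Å cubic cell:

import torch
from einops import einsum

# a single cubic lattice of length 5 Å, shaped (n=1, 3, 3) as in the diff
cell = 5.0 * torch.eye(3).unsqueeze(0)
# periodic image vectors for three hypothetical edges, shaped (v, 3)
images = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [-1.0, 1.0, 0.0]])
# contract each image vector with the lattice vectors: "v i, n i j -> v j"
offsets = einsum(images, cell, "v i, n i j -> v j")
print(offsets)  # rows: [0, 0, 0], [5, 0, 0], [-5, 5, 0]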