Add atomization-consistent training and other features/fixes (#123)
* Update example

* update imports to base package

* add atomization torch code

* Improve error message

* fix type check in database

* allow bias=None with hierarchical pretraining

* add core code for atomization consistency

* allow value nodes to stay stored as arbitrary python objects

* Improve fallback handling for unique node finder

* update interface to ase and lammps for atomization consistent networks

* add test for atomization consistent learning

* fix issue #121

* fix stray debugging code

* update changelog and documentation

* update example with additional arguments

* recover if lammps installation was not complete

* update default for example

* add example readme

* update example defaults

* Fix bug with custom kernel warnings
lubbersnick authored Jan 22, 2025
1 parent 6dd3a91 commit 78059c7
Showing 20 changed files with 493 additions and 66 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.rst
@@ -9,7 +9,7 @@ Breaking changes:
functions are ``load_checkpoint``, ``load_checkpoint_from_cwd``, and
``restore_checkpoint``.
- ``database.make_trainvalidtest_split`` now only takes keyword arguments to
-avoid confusions. Use ``make_trainvalidtest_split(test_size=a, valid_size=b)``
+avoid confusion. Use ``make_trainvalidtest_split(test_size=a, valid_size=b)``
instead of ``make_trainvalidtest_split(a, b)``.
- Invalid custom kernel specifications are now errors rather than warnings.
- Method of specifying units for custom MD algorithms has changed.
@@ -31,6 +31,7 @@ New Features:
are compatible with molecular dynamics codes such as ASE and LAMMPS.
- Added the ability to weight different systems/atoms/bonds in a loss function.
- Added new function to reload library settings.
+- Added an atomization-consistent energy node, which exactly constrains its predictions in the dissociated limit.


Improvements:
@@ -47,6 +48,7 @@ Improvements:
- Improved detection of valid custom kernel implementation.
- Improved computational efficiency of HIP-NN-TS network.
- ``StressForceNode`` now also works with batch size greater than 1.
+- Allow testing of splits with arbitrary names using ``test_model``, as long as those splits contain the required variables.


Bug Fixes:
16 changes: 11 additions & 5 deletions docs/source/examples/ase_calculator.rst
@@ -3,7 +3,8 @@ ASE Calculators

Hippynn models can be used with ``ase`` to perform molecular dynamics or other tests.

-To build an ASE calculator, you must pass the node associated with energy.
+To build an ASE calculator, use the :class:`~hippynn.interfaces.ase_interface.HippynnCalculator` object.
+Pass it the node associated with energy.
Example::

from hippynn.interfaces.ase_interface import HippynnCalculator
@@ -13,11 +14,16 @@ Example::

Take note of the ``en_unit`` and ``dist_unit`` parameters for the calculator.
These parameters inform the calculator what units the model consumes and produces for energy and
-for distance. If unspecified, the ``en_unit`` is kcal/mol, and the ``dist_unit`` is angstrom.
-Whatever units your model uses, the output of the calculator will be in the ``ase`` unit system,
-which has energy in eV and distance in Angstroms.
+for distance. If unspecified, the ``en_unit`` is kcal/mol (note that this differs from the ase default of eV!),
+and the ``dist_unit`` is angstrom. Whatever units your model uses, the output of the calculator
+will be in the ``ase`` unit system, which has energy in eV and distance in Angstroms.

-Given an ase atoms object, one can assign the calculator::
+If your model contains only one energy node, you can use the function
+:func:`~hippynn.interfaces.ase_interface.calculator_from_model`, which automatically identifies
+the energy node and creates the calculator from it. This function accepts keyword arguments that
+will be passed to the calculator.
+
+Given an ase ``atoms`` object, one can assign the calculator::

atoms.calc = calc
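
For orientation, here is a hedged end-to-end sketch of the pieces above. It assumes a checkpoint
saved by hippynn's own tools (the ``load_checkpoint_from_cwd`` helper named in the changelog) and
a hypothetical ``structure.xyz`` geometry file; adapt both to your workflow::

    import ase.io
    import ase.units
    from hippynn.experiment.serialization import load_checkpoint_from_cwd
    from hippynn.interfaces.ase_interface import calculator_from_model

    # Restore a trained model from the current directory (assumed checkpoint layout).
    check = load_checkpoint_from_cwd()
    model = check["training_modules"].model

    # Assuming the model works in kcal/mol and angstrom (hippynn's defaults).
    calc = calculator_from_model(model, en_unit=ase.units.kcal / ase.units.mol)

    atoms = ase.io.read("structure.xyz")  # hypothetical input file
    atoms.calc = calc
    print(atoms.get_potential_energy())   # always reported in eV, the ase unit system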

10 changes: 5 additions & 5 deletions docs/source/user_guide/databases.rst
@@ -24,11 +24,11 @@ Note that input of bond variables for periodic systems can be ill-defined
if there are multiple bonds between the same pairs of atoms. This is not yet
supported.

-A note on *cell* variables. The shape of a cell variable should be specified as (n_atoms,3,3).
-There are two common conventions for the cell matrix itself; we use the convention that the basis index
-comes first, and the cartesian index comes second. That is similar to `ase`,
-the [i,j] element of the cell gives the j cartesian coordinate of cell vector i. If you experience
-massive difficulties fitting to periodic boundary conditions, you may check the transposed version
+A note on *cell* variables. The shape of a cell variable should be specified as (n_systems,3,3), as described above.
+It is important to know that there are two common conventions for the cell matrix itself; we use the convention
+that the basis index comes first, and the cartesian index comes second. That is, similar to the ``ase`` package,
+the element ``cell[sys,i,j]`` gives the ``j`` cartesian coordinate of cell vector ``i`` in system ``sys``. If you
+experience massive errors while fitting to periodic boundary conditions, you may check the transposed version
of your cell data, or compute the RDF.
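
A minimal sketch of checking this convention on your own data, assuming a hypothetical
``cells.npy`` file holding an ``(n_systems, 3, 3)`` array::

    import numpy as np

    cells = np.load("cells.npy")   # hypothetical file, shape (n_systems, 3, 3)
    # Basis index first, cartesian second: cells[sys, i, :] is lattice vector i of system sys.
    print(cells[0, 0, :])          # the first lattice vector of the first system
    # If periodic fits go badly wrong, try the per-system transpose instead:
    cells_transposed = np.transpose(cells, (0, 2, 1))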

Database Formats and notes
29 changes: 29 additions & 0 deletions examples/README.md
@@ -0,0 +1,29 @@
Example files for HIPPYNN
-------------------------

The files in this directory provide examples for hippynn.
Because hippynn is for training and running machine learning models,
it requires data to run! Our example files typically contain
a header which describes where to get the data in order to run the example.
Occasionally an example file for running a model depends on running a
different example for training that model.


Suggested Starting Points
-------------------------

- If you're looking to get a basic picture of how hippynn works, see
  ``barebones.py``, which is about as simple as a working example can be.
  If you have a basic understanding but want to see more details,
  check out the Jupyter notebook ``graph_exploration.ipynb``.

- If you're looking to customize a nicely constructed training script,
  look into ``ani1x_training.py``. This trains on a large dataset
  known as ani1x, separates out the various aspects of the library,
  and demonstrates how they interact. It has many optional features included.

- If you want to see how to use a model with ase or lammps,
  first train a model with ``ani_aluminum_example.py``.
  Then check out ``ase_example.py`` or the ``examples/lammps/``
  directory, respectively.
71 changes: 52 additions & 19 deletions examples/ani1x_training.py
@@ -19,7 +19,7 @@
import ase.units


-def make_model(network_params,tensor_order):
+def make_model(network_params, tensor_order, atomization_consistent):
"""
Build the model graph for energy and potentially force prediction.
"""
@@ -33,13 +33,18 @@ def make_model(network_params, tensor_order, atomization_consistent):
species = inputs.SpeciesNode(db_name="atomic_numbers")
positions = inputs.PositionsNode(db_name="coordinates")
network = net_class("hipnn_model", (species, positions), module_kwargs=network_params)
-henergy = targets.HEnergyNode("HEnergy", network)
+if not atomization_consistent:
+    henergy = targets.HEnergyNode("HEnergy", network)
+else:
+    henergy = targets.AtomizationEnergyNode("HEnergy", network)

force = physics.GradientNode("forces", (henergy, positions), sign=-1)

return henergy, force


-def make_loss(henergy, force,force_training):
+def make_loss(henergy, force, force_training):
"""
Build the loss graph for energy and force error.
"""
@@ -68,13 +73,16 @@ def make_loss(henergy, force, force_training):
return losses


-# wb97x-6-31g*, G16. Doesn't need to be exact for most models.
-SELF_ENERGY_APPROX = {'C': -37.764142, 'H': -0.4993212, 'N': -54.4628753, 'O': -74.940046}
+# wb97x-6-31g*, G16. Doesn't need to be exact for most models, except atomization consistent.
+# # # Old values with singlet/triplet multiplicity only
+# # SELF_ENERGY_APPROX = {'C': -37.764142, 'H': -0.4993212, 'N': -54.4628753, 'O': -74.940046}
+# Recalculated with appropriate vacuum multiplicity
+SELF_ENERGY_APPROX = {'C': -37.8338334397, 'H': -0.499321232710, 'N': -54.5732824628, 'O': -75.0424519384}
SELF_ENERGY_APPROX = {k: SELF_ENERGY_APPROX[v] for k, v in zip([6, 1, 7, 8], 'CHNO')}
SELF_ENERGY_APPROX[0] = 0
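
The re-keying comprehension just above works because ``zip([6, 1, 7, 8], 'CHNO')`` yields
``(6, 'C')``, ``(1, 'H')``, and so on, leaving the table keyed by atomic number, with key ``0``
(the blank/padding species) mapped to zero. A quick sketch with abbreviated values::

    by_symbol = {'C': -37.83, 'H': -0.50}
    by_number = {k: by_symbol[v] for k, v in zip([6, 1], 'CH')}
    assert by_number == {6: -37.83, 1: -0.50}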


-def load_db(db_info, en_name, force_name, seed, anidata_location, n_workers):
+def load_db(db_info, en_name, force_name, seed, anidata_location, n_workers, use_ccx_subset):
"""
Load the database.
"""
@@ -83,14 +91,19 @@ def load_db(db_info, en_name, force_name, seed, anidata_location, n_workers):

# Ensure total energies loaded in float64.
torch.set_default_dtype(torch.float64)
import os

+CCX_EN_NAME = "ccsd(t)_cbs.energy"
+if use_ccx_subset:
+    db_info['targets'].append(CCX_EN_NAME)
database = PyAniFileDB(
file=anidata_location,
species_key='atomic_numbers',
seed=seed,
num_workers=n_workers,
**db_info
)
+if en_name != CCX_EN_NAME:
+    database.targets = [x for x in database.targets if x != CCX_EN_NAME]

# compute (approximate) atomization energy by subtracting self energies
self_energy = np.vectorize(SELF_ENERGY_APPROX.__getitem__)(database.arr_dict['atomic_numbers'])
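
Here ``np.vectorize(SELF_ENERGY_APPROX.__getitem__)`` turns the dictionary lookup into an
elementwise operation over the padded ``(n_systems, n_atoms)`` species array. A sketch of the
trick on toy data::

    import numpy as np

    table = {0: 0.0, 1: -0.5, 6: -37.83}
    numbers = np.array([[6, 1, 1, 0]])                   # one molecule, zero-padded
    self_en = np.vectorize(table.__getitem__)(numbers)   # [[-37.83, -0.5, -0.5, 0.0]]
    per_system = self_en.sum(axis=1)                     # summed per system
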
@@ -102,12 +115,15 @@
if force_name in database.arr_dict:
database.arr_dict[force_name] = database.arr_dict[force_name]*conversion
torch.set_default_dtype(torch.float32)
-database.arr_dict['atomic_numbers']=database.arr_dict['atomic_numbers'].astype(np.int64)
+database.arr_dict['atomic_numbers'] = database.arr_dict['atomic_numbers'].astype(np.int64)

# Drop indices where computed energy not retrieved.
-found_indices = ~np.isnan(database.arr_dict[en_name])
+if use_ccx_subset:
+    filter_name = CCX_EN_NAME
+else:
+    filter_name = en_name
+found_indices = ~np.isnan(database.arr_dict[filter_name])
database.arr_dict = {k: v[found_indices] for k, v in database.arr_dict.items()}

database.make_trainvalidtest_split(test_size=0.1, valid_size=0.1)
return database
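
When ``use_ccx_subset`` is set, the coupled-cluster column is loaded only as a mask: systems whose
CCSD(T) energy is ``NaN`` are dropped, while training still targets the ``qm_method`` energies.
The masking pattern, sketched on toy arrays::

    import numpy as np

    energies = np.array([1.0, np.nan, 2.0])
    mask = ~np.isnan(energies)                        # [True, False, True]
    arrays = {"en": energies, "idx": np.arange(3)}
    arrays = {k: v[mask] for k, v in arrays.items()}  # keeps systems 0 and 2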

@@ -166,7 +182,8 @@ def get_data_names(qm_method, basis_set):

def main(args):
torch.manual_seed(args.seed)
-torch.cuda.set_device(args.gpu)
+if args.use_gpu:
+    torch.cuda.set_device(args.gpu)
torch.set_default_dtype(torch.float32)

hippynn.settings.WARN_LOW_DISTANCES = False
@@ -187,14 +204,16 @@

with hippynn.tools.active_directory(netname):
with hippynn.tools.log_terminal("training_log.txt", 'wt'):
-henergy, force = make_model(network_parameters,tensor_order=args.tensor_order)
+henergy, force = make_model(network_parameters,
+                            tensor_order=args.tensor_order,
+                            atomization_consistent=args.atomization_consistent)

en_name, force_name = get_data_names(args.qm_method, args.basis_set)

henergy.mol_energy.db_name = en_name
force.db_name = force_name

-validation_losses = make_loss(henergy, force,force_training=args.force_training)
+validation_losses = make_loss(henergy, force, force_training=args.force_training)

train_loss = validation_losses["LossTotal"]

@@ -207,17 +226,22 @@
force_name,
n_workers=args.n_workers,
seed=args.seed,
-anidata_location=args.anidata_location)
+anidata_location=args.anidata_location,
+use_ccx_subset=args.use_ccx_subset)

from hippynn.pretraining import hierarchical_energy_initialization

hierarchical_energy_initialization(henergy, database, trainable_after=False)

+patience = args.patience
+if args.use_ccx_subset:
+    patience *= 4

setup_params = setup_experiment(training_modules,
device=args.gpu,
batch_size=args.batch_size,
init_lr=args.init_lr,
-patience=args.patience,
+patience=patience,
max_epochs=args.max_epochs,
stopping_key=args.stopping_key,
)
@@ -235,29 +259,38 @@

parser.add_argument("--tag", type=str, default="TEST_MODEL_ANI1X", help='name for run')
parser.add_argument("--gpu", type=int, default=0, help='which GPU to run on')
parser.add_argument("--use_gpu", type=bool, default=False, help='Whether to use GPU')

parser.add_argument("--seed", type=int, default=0, help='random seed for init and split')

parser.add_argument("--n_interactions", type=int, default=2)
parser.add_argument("--n_atom_layers", type=int, default=3)
parser.add_argument("--n_features", type=int, default=20)
parser.add_argument("--n_features", type=int, default=128)
parser.add_argument("--n_sensitivities", type=int, default=20)
parser.add_argument("--cutoff_distance", type=float, default=6.5)
parser.add_argument("--lower_cutoff",type=float,default=0.75,
help="Where to initialize the shortest distance sensitivty")
help="Where to initialize the shortest distance sensitivity")
parser.add_argument("--tensor_order",type=int,default=0)
parser.add_argument("--atomization_consistent", type=bool, default=False)

parser.add_argument("--anidata_location", type=str, default='../../../datasets/ani1x_release/ani1x-release.h5')
parser.add_argument("--qm_method", type=str, default='wb97x')
parser.add_argument("--basis_set", type=str, default='dz')

parser.add_argument("--force_training", action='store_true', default=False)
parser.add_argument("--force_training", action='store_true', default=True)

parser.add_argument("--batch_size",type=int, default=1024)
parser.add_argument("--batch_size",type=int, default=256)
parser.add_argument("--init_lr",type=float, default=1e-3)
parser.add_argument("--patience",type=int, default=5)
parser.add_argument("--max_epochs",type=int, default=500)
parser.add_argument("--stopping_key",type=str, default="T-RMSE")

parser.add_argument("--use_ccx_subset",type=bool, default=False,
help="Train only to configurations from the ANI-1ccx subset."
" Note that this will still use the energies using the `qm_method` argument."
" *Note!* This argument will multiply the patience by a factor of 4.")


parser.add_argument("--noprogress", action='store_true', default=False, help='suppress progress bars')
parser.add_argument("--n_workers", type=int, default=2, help='workers for pytorch dataloaders')
args = parser.parse_args()
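
A caution on the ``type=bool`` arguments above (``--use_gpu``, ``--atomization_consistent``,
``--use_ccx_subset``): argparse applies ``bool()`` to the raw string, and any non-empty string is
truthy, so ``--use_gpu False`` on the command line still yields ``True``. These flags behave as
written only with their defaults or when set programmatically. A sketch of the pitfall::

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--use_gpu", type=bool, default=False)
    print(p.parse_args(["--use_gpu", "False"]).use_gpu)  # True: bool("False") is True

    # The usual alternative is a store_true flag:
    p2 = argparse.ArgumentParser()
    p2.add_argument("--use_gpu", action="store_true")
    print(p2.parse_args([]).use_gpu)                     # False unless --use_gpu is passed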
35 changes: 31 additions & 4 deletions hippynn/__init__.py
@@ -12,7 +12,6 @@
from . import _version
__version__ = _version.get_versions()['version']


# Configuration settings
from ._settings_setup import settings, reload_settings

@@ -30,7 +29,14 @@

# Graph abstractions
from . import graphs
-from .graphs import nodes, IdxType, GraphModule, Predictor
+from .graphs import nodes, IdxType, GraphModule, Predictor, make_ensemble

+# Kinds of nodes
+from .graphs.nodes import inputs, targets, loss, pairs, physics, indexers
+from .graphs.nodes import networks as network_nodes
+
+from . import pretraining
+from .pretraining import hierarchical_energy_initialization

# Database loading
from . import databases
@@ -55,9 +61,30 @@
del ase
from . import molecular_dynamics
from . import optimizer
from .interfaces import ase_interface

-from . import pretraining
-from .pretraining import hierarchical_energy_initialization
+# Submodules that require pyseqm
+try:
+    import seqm
+except ImportError:
+    pass
+else:
+    del seqm
+    from .interfaces import pyseqm_interface
+
+# Submodules that require lammps
+try:
+    import lammps
+except ImportError:
+    pass
+else:
+    del lammps
+    try:
+        from .interfaces import lammps_interface
+    except Exception as eee:
+        import warnings
+        warnings.warn(f"Lammps interface was not importable due to exception: {eee}")
+        del eee, warnings

# The order is adjusted to put functions after objects in the documentation.
_dir = dir()
2 changes: 1 addition & 1 deletion hippynn/custom_kernels/__init__.py
@@ -127,7 +127,7 @@ def set_custom_kernels(active: Union[bool, str] = True) -> str:
active = active.lower()

if active not in _POSSIBLE_CUSTOM_KERNELS:
-raise warnings.warn(f"Using non-standard custom kernel implementation: {active}")
+warnings.warn(f"Using non-standard custom kernel implementation: {active}")

# Our goal is that this if-block is to handle the cases for values in the range of
# [True, "auto"] and turn them into the suitable actual implementation.
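
This fix matters because ``warnings.warn`` returns ``None``, so the old ``raise warnings.warn(...)``
emitted the warning and then crashed trying to raise ``None``. A sketch of the failure mode::

    import warnings

    try:
        raise warnings.warn("non-standard kernel")  # the old, buggy pattern
    except TypeError as err:
        print(err)  # "exceptions must derive from BaseException"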
2 changes: 1 addition & 1 deletion hippynn/databases/database.py
@@ -280,7 +280,7 @@ def make_explicit_split_bool(self, split_name: str,
:param split_mask: a boolean array for where to split
:return:
"""
-if isinstance(split_mask, torch.tensor):
+if isinstance(split_mask, torch.Tensor):
split_mask = split_mask.numpy()
if split_mask.dtype != np.bool_:
if not np.isin(split_mask, [0, 1]).all():
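
The one-character fix above is load-bearing: ``torch.Tensor`` is the tensor class, while
``torch.tensor`` is a factory function, and ``isinstance`` rejects anything that is not a type.
A sketch::

    import torch

    x = torch.ones(3)
    print(isinstance(x, torch.Tensor))  # True
    # isinstance(x, torch.tensor)      # TypeError: isinstance() arg 2 must be a type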
17 changes: 15 additions & 2 deletions hippynn/experiment/routines.py
@@ -5,6 +5,7 @@
"""
import sys
import collections
+import warnings
from dataclasses import dataclass
import copy
import timeit
@@ -354,10 +355,22 @@ def test_model(database, evaluator, batch_size, when, metric_tracker=None):
if metric_tracker is None:
metric_tracker = MetricTracker(evaluator.loss_names, stopping_key=None)

+# Determine splits which are complete and can be evaluated:
+evaluatable_splits = []
+required_variables = set(database.inputs + database.targets)
+for sname, split in database.splits.items():
+    if all(k in split for k in required_variables):
+        evaluatable_splits.append(sname)
+    else:
+        missing_arrays = set(k for k in required_variables if k not in split)
+        warnings.warn(f"Database contains split '{sname}' which"
+                      f" cannot be evaluated because it does not contain the"
+                      f" required quantities: {missing_arrays}")

# A little dance to make sure train, valid, test always come first, when present.
basic_splits = ["train", "valid", "test"]
-basic_splits = [s for s in basic_splits if s in database.splits]
-splits = basic_splits + [s for s in database.splits if s not in basic_splits]
+basic_splits = [s for s in basic_splits if s in evaluatable_splits]
+splits = basic_splits + [s for s in evaluatable_splits if s not in basic_splits]

evaluation_data = collections.OrderedDict(
(
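
This is the behavior the changelog entry describes: any split stored on the database under a
custom name is now evaluated by ``test_model``, provided it carries every required input and
target array; incomplete splits are skipped with a warning. A hedged usage sketch, assuming a
``make_explicit_split`` method alongside the ``make_explicit_split_bool`` shown earlier, plus
``database`` and ``evaluator`` objects from a training session::

    import numpy as np

    # Store a custom-named split; test_model will now report metrics for it.
    database.make_explicit_split("my_holdout", np.arange(100))
    metrics = test_model(database, evaluator, batch_size=64, when="my_test_run")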