
Commit

Merge pull request #442 from mir-group/develop
Linux-cpp-lisp authored Jul 9, 2024
2 parents 3db6964 + 69385ab commit d3a7763
Showing 35 changed files with 520 additions and 123 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lint.yaml
@@ -18,7 +18,7 @@ jobs:
- name: Black Check
uses: psf/black@stable
with:
version: "22.3.0"
version: "24.4.2"

flake8:
runs-on: ubuntu-latest
@@ -29,7 +29,7 @@
python-version: '3.x'
- name: Install flake8
run: |
pip install flake8==7.0.0
pip install flake8==7.1.0
- name: run flake8
run: |
flake8 . --count --show-source --statistics
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -31,6 +31,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install setuptools wheel
if [ ${TORCH} = "1.13.1" ]; then pip install numpy==1.*; fi # older torch versions fail with numpy 2
pip install torch==${TORCH} -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install h5py scikit-learn # install packages that aren't required dependencies but that the tests do need
pip install --upgrade-strategy only-if-needed .
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -4,11 +4,11 @@ fail_fast: true

repos:
- repo: https://github.com/psf/black
rev: 22.3.0
rev: 24.4.2
hooks:
- id: black

- repo: https://gitlab.com/pycqa/flake8
rev: 4.0.1
- repo: https://github.com/pycqa/flake8
rev: 7.1.0
hooks:
- id: flake8
17 changes: 17 additions & 0 deletions CHANGELOG.md
@@ -6,8 +6,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

Most recent change on the bottom.


## Unreleased


## [0.6.1] - 2024-7-9
### Added
- add support for equivariance testing of arbitrary Cartesian tensor outputs
- [Breaking] use entry points for `nequip.extension`s (e.g. for field registration)
- alternate neighborlist support enabled with `NEQUIP_NL` environment variable, which can be set to `ase` (default), `matscipy` or `vesin`
- Allow `n_train` and `n_val` to be specified as percentages of datasets.
- Only attempt training restart if `trainer.pth` file present (prevents unnecessary crashes due to file-not-found errors in some cases)

### Changed
- [Breaking] `NEQUIP_MATSCIPY_NL` environment variable no longer supported

### Fixed
- Fixed `flake8` install location in `pre-commit-config.yaml`


## [0.6.0] - 2024-5-10
### Added
- add Tensorboard as logger option
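The changelog above introduces the `NEQUIP_NL` environment variable for choosing the neighborlist backend. A minimal usage sketch (assuming `nequip` and the chosen backend, e.g. `matscipy`, are installed; the variable must be set before `nequip.data` is first imported, since `AtomicData.py` reads it at import time, as shown in the diff below):

```python
import os

# Choose the neighborlist backend: "ase" (default), "matscipy", or "vesin".
# This must happen before nequip.data is imported, because AtomicData.py
# reads NEQUIP_NL at module import time.
os.environ["NEQUIP_NL"] = "matscipy"

from nequip.data import AtomicData, AtomicDataDict  # noqa: E402
```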
3 changes: 3 additions & 0 deletions configs/full.yaml
@@ -181,6 +181,9 @@ save_ema_checkpoint_freq: -1
# training
n_train: 100 # number of training data
n_val: 50 # number of validation data
# alternatively, n_train and n_val can be set as percentages of the dataset size:
# n_train: 70% # 70% of dataset
# n_val: 30% # 30% of dataset (if validation_dataset not set), or 30% of validation_dataset (if set)
learning_rate: 0.005 # learning rate, we found values between 0.01 and 0.005 to work best - this is often one of the most important hyperparameters to tune
batch_size: 5 # batch size, we found it important to keep this small for most applications including forces (1-5); for energy-only training, higher batch sizes work better
validation_batch_size: 10 # batch size for evaluating the model during validation. This does not affect the training results, but using the highest value possible (<=n_val) without running out of memory will speed up your training.
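For reference, a rough sketch of how a percentage string such as `70%` could be resolved into an absolute frame count (the actual parsing lives inside nequip's trainer; this helper is purely illustrative):

```python
def resolve_split(spec, dataset_size: int) -> int:
    # Illustrative only: turn "70%" or 70 into an absolute number of frames.
    if isinstance(spec, str) and spec.endswith("%"):
        return int(round(float(spec[:-1]) / 100.0 * dataset_size))
    return int(spec)

print(resolve_split("70%", 1000))  # 700
print(resolve_split(100, 1000))    # 100
```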
37 changes: 37 additions & 0 deletions nequip/__init__.py
@@ -1 +1,38 @@
import sys

from ._version import __version__ # noqa: F401

import packaging.version

import torch
import warnings

# torch version checks
torch_version = packaging.version.parse(torch.__version__)

# only allow 1.11*, 1.13* or higher (no 1.12.*)
assert (torch_version == packaging.version.parse("1.11")) or (
torch_version >= packaging.version.parse("1.13")
), f"NequIP supports PyTorch 1.11.* or 1.13.* or later, but {torch_version} found"

# warn if using 1.13* or 2.0.*
if packaging.version.parse("1.13.0") <= torch_version:
warnings.warn(
f"!! PyTorch version {torch_version} found. Upstream issues in PyTorch versions 1.13.* and 2.* have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. The best tested PyTorch version to use with CUDA devices is 1.11; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue."
)


# Load all installed nequip extension packages
# This allows installed extensions to register themselves in
# the nequip infrastructure with calls like `register_fields`

# see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata
if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points

_DISCOVERED_NEQUIP_EXTENSION = entry_points(group="nequip.extension")
for ep in _DISCOVERED_NEQUIP_EXTENSION:
if ep.name == "init_always":
ep.load()
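The discovery loop above loads any installed package that advertises a `nequip.extension` entry point named `init_always`. A hypothetical extension package (names are illustrative, not part of this commit) could declare it in its `setup.py` like so:

```python
# setup.py of a hypothetical extension package "my_nequip_ext"
from setuptools import setup, find_packages

setup(
    name="my_nequip_ext",
    version="0.1.0",
    packages=find_packages(),
    entry_points={
        # nequip loads the entry point named "init_always" at import time,
        # so the referenced module can call nequip.data.register_fields() on import.
        "nequip.extension": ["init_always = my_nequip_ext._init"],
    },
)
```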
2 changes: 1 addition & 1 deletion nequip/_version.py
@@ -2,4 +2,4 @@
# See Python packaging guide
# https://packaging.python.org/guides/single-sourcing-package-version/

__version__ = "0.6.0"
__version__ = "0.6.1"
134 changes: 109 additions & 25 deletions nequip/data/AtomicData.py
@@ -10,14 +10,14 @@
import os

import numpy as np
import ase.neighborlist
import ase
from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator
from ase.calculators.calculator import all_properties as ase_all_properties
from ase.stress import voigt_6_to_full_3x3_stress, full_3x3_to_voigt_6_stress

import torch
import e3nn.o3
from e3nn.io import CartesianTensor

from . import AtomicDataDict
from ._util import _TORCH_INTEGER_DTYPES
@@ -26,6 +26,7 @@
# A type representing ASE-style periodic boundary conditions, which can be partial (the tuple case)
PBC = Union[bool, Tuple[bool, bool, bool]]

# === Key Registration ===

_DEFAULT_LONG_FIELDS: Set[str] = {
AtomicDataDict.EDGE_INDEX_KEY,
@@ -61,17 +62,23 @@
AtomicDataDict.CELL_KEY,
AtomicDataDict.BATCH_PTR_KEY,
}
_DEFAULT_CARTESIAN_TENSOR_FIELDS: Dict[str, str] = {
AtomicDataDict.STRESS_KEY: "ij=ji",
AtomicDataDict.VIRIAL_KEY: "ij=ji",
}
_NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS)
_EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS)
_GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS)
_LONG_FIELDS: Set[str] = set(_DEFAULT_LONG_FIELDS)
_CARTESIAN_TENSOR_FIELDS: Dict[str, str] = dict(_DEFAULT_CARTESIAN_TENSOR_FIELDS)


def register_fields(
node_fields: Sequence[str] = [],
edge_fields: Sequence[str] = [],
graph_fields: Sequence[str] = [],
long_fields: Sequence[str] = [],
cartesian_tensor_fields: Dict[str, str] = {},
) -> None:
r"""Register fields as being per-atom, per-edge, or per-frame.
@@ -83,18 +90,36 @@ def register_fields(
edge_fields: set = set(edge_fields)
graph_fields: set = set(graph_fields)
long_fields: set = set(long_fields)
allfields = node_fields.union(edge_fields, graph_fields)
assert len(allfields) == len(node_fields) + len(edge_fields) + len(graph_fields)

# error checking: prevents registering fields as contradictory types
# potentially unregistered fields
assert len(node_fields.intersection(edge_fields)) == 0
assert len(node_fields.intersection(graph_fields)) == 0
assert len(edge_fields.intersection(graph_fields)) == 0
# already registered fields
assert len(_NODE_FIELDS.intersection(edge_fields)) == 0
assert len(_NODE_FIELDS.intersection(graph_fields)) == 0
assert len(_EDGE_FIELDS.intersection(node_fields)) == 0
assert len(_EDGE_FIELDS.intersection(graph_fields)) == 0
assert len(_GRAPH_FIELDS.intersection(edge_fields)) == 0
assert len(_GRAPH_FIELDS.intersection(node_fields)) == 0

# check that Cartesian tensor fields to add are rank-2 (higher ranks not supported)
for cart_tensor_key in cartesian_tensor_fields:
cart_tensor_rank = len(
CartesianTensor(cartesian_tensor_fields[cart_tensor_key]).indices
)
if cart_tensor_rank != 2:
raise NotImplementedError(
f"Only rank-2 tensor data processing supported, but got {cart_tensor_key} is rank {cart_tensor_rank}. Consider raising a GitHub issue if higher-rank tensor data processing is desired."
)

# update fields
_NODE_FIELDS.update(node_fields)
_EDGE_FIELDS.update(edge_fields)
_GRAPH_FIELDS.update(graph_fields)
_LONG_FIELDS.update(long_fields)
if len(set.union(_NODE_FIELDS, _EDGE_FIELDS, _GRAPH_FIELDS)) < (
len(_NODE_FIELDS) + len(_EDGE_FIELDS) + len(_GRAPH_FIELDS)
):
raise ValueError(
"At least one key was registered as more than one of node, edge, or graph!"
)
_CARTESIAN_TENSOR_FIELDS.update(cartesian_tensor_fields)


def deregister_fields(*fields: Sequence[str]) -> None:
@@ -109,9 +134,16 @@ def deregister_fields(*fields: Sequence[str]) -> None:
assert f not in _DEFAULT_NODE_FIELDS, "Cannot deregister built-in field"
assert f not in _DEFAULT_EDGE_FIELDS, "Cannot deregister built-in field"
assert f not in _DEFAULT_GRAPH_FIELDS, "Cannot deregister built-in field"
assert f not in _DEFAULT_LONG_FIELDS, "Cannot deregister built-in field"
assert (
f not in _DEFAULT_CARTESIAN_TENSOR_FIELDS
), "Cannot deregister built-in field"

_NODE_FIELDS.discard(f)
_EDGE_FIELDS.discard(f)
_GRAPH_FIELDS.discard(f)
_LONG_FIELDS.discard(f)
_CARTESIAN_TENSOR_FIELDS.pop(f, None)


def _register_field_prefix(prefix: str) -> None:
@@ -125,6 +157,9 @@ def _register_field_prefix(prefix: str) -> None:
)


# === AtomicData ===


def _process_dict(kwargs, ignore_fields=[]):
"""Convert a dict of data into correct dtypes/shapes according to key"""
# Deal with _some_ dtype issues
@@ -449,17 +484,40 @@ def from_ase(
cell = kwargs.pop("cell", atoms.get_cell())
pbc = kwargs.pop("pbc", atoms.pbc)

# handle ASE-style 6 element Voigt order stress
for key in (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY):
if key in add_fields:
if add_fields[key].shape == (3, 3):
# it's already 3x3, do nothing else
pass
elif add_fields[key].shape == (6,):
# it's Voigt order
add_fields[key] = voigt_6_to_full_3x3_stress(add_fields[key])
# IMPORTANT: the following reshape logic only applies to rank-2 Cartesian tensor fields
for key in add_fields:
if key in _CARTESIAN_TENSOR_FIELDS:
# enforce (3, 3) shape for graph fields, e.g. stress, virial
if key in _GRAPH_FIELDS:
# handle ASE-style 6 element Voigt order stress
if key in (AtomicDataDict.STRESS_KEY, AtomicDataDict.VIRIAL_KEY):
if add_fields[key].shape == (6,):
add_fields[key] = voigt_6_to_full_3x3_stress(
add_fields[key]
)
if add_fields[key].shape == (3, 3):
# it's already 3x3, do nothing else
pass
elif add_fields[key].shape == (9,):
add_fields[key] = add_fields[key].reshape((3, 3))
else:
raise RuntimeError(
f"bad shape for {key} registered as a Cartesian tensor graph field---please note that only rank-2 Cartesian tensors are currently supported"
)
# enforce (N_atom, 3, 3) shape for node fields, e.g. Born effective charges
elif key in _NODE_FIELDS:
if add_fields[key].shape[1:] == (3, 3):
pass
elif add_fields[key].shape[1:] == (9,):
add_fields[key] = add_fields[key].reshape((-1, 3, 3))
else:
raise RuntimeError(
f"bad shape for {key} registered as a Cartesian tensor node field---please note that only rank-2 Cartesian tensors are currently supported"
)
else:
raise RuntimeError(f"bad shape for {key}")
raise RuntimeError(
f"{key} registered as a Cartesian tensor field was not registered as either a graph or node field"
)

return cls.from_points(
pos=atoms.positions,
@@ -705,12 +763,21 @@ def without_nodes(self, which_nodes):
assert _ERROR_ON_NO_EDGES in ("true", "false")
_ERROR_ON_NO_EDGES = _ERROR_ON_NO_EDGES == "true"

_NEQUIP_MATSCIPY_NL: Final[bool] = os.environ.get("NEQUIP_MATSCIPY_NL", "false").lower()
assert _NEQUIP_MATSCIPY_NL in ("true", "false")
_NEQUIP_MATSCIPY_NL = _NEQUIP_MATSCIPY_NL == "true"
# use "ase" as default
# TODO: eventually, choose fastest as default
# NOTE:
# - vesin and matscipy do not support self-interaction
# - vesin does not allow for mixed pbcs
_NEQUIP_NL: Final[str] = os.environ.get("NEQUIP_NL", "ase").lower()

if _NEQUIP_MATSCIPY_NL:
if _NEQUIP_NL == "vesin":
from vesin import NeighborList as vesin_nl
elif _NEQUIP_NL == "matscipy":
import matscipy.neighbours
elif _NEQUIP_NL == "ase":
import ase.neighborlist
else:
raise NotImplementedError(f"Unknown neighborlist NEQUIP_NL = {_NEQUIP_NL}")


def neighbor_list_and_relative_vec(
@@ -790,7 +857,24 @@ def neighbor_list_and_relative_vec(
# ASE dependent part
temp_cell = ase.geometry.complete_cell(temp_cell)

if _NEQUIP_MATSCIPY_NL:
if _NEQUIP_NL == "vesin":
assert strict_self_interaction and not self_interaction
# use same mixed pbc logic as
# https://github.com/Luthaf/vesin/blob/main/python/vesin/src/vesin/_ase.py
if pbc[0] and pbc[1] and pbc[2]:
periodic = True
elif not pbc[0] and not pbc[1] and not pbc[2]:
periodic = False
else:
raise ValueError(
"different periodic boundary conditions on different axes are not supported by vesin neighborlist, use ASE or matscipy"
)

first_idex, second_idex, shifts = vesin_nl(
cutoff=float(r_max), full_list=True
).compute(points=temp_pos, box=temp_cell, periodic=periodic, quantities="ijS")

elif _NEQUIP_NL == "matscipy":
assert strict_self_interaction and not self_interaction
first_idex, second_idex, shifts = matscipy.neighbours.neighbour_list(
"ijS",
@@ -799,7 +883,7 @@
positions=temp_pos,
cutoff=float(r_max),
)
else:
elif _NEQUIP_NL == "ase":
first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list(
"ijS",
pbc,
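A short sketch (the field name is hypothetical) of registering a rank-2 Cartesian tensor as a per-atom field with the new `cartesian_tensor_fields` argument, e.g. for Born effective charges; this assumes `register_fields` is importable from `nequip.data` as in previous releases:

```python
from nequip.data import register_fields

# "born_charge" is an illustrative field name, not one defined by nequip itself.
register_fields(
    node_fields=["born_charge"],
    # e3nn CartesianTensor formula "ij": general rank-2 tensor, no index symmetry
    cartesian_tensor_fields={"born_charge": "ij"},
)
```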
6 changes: 5 additions & 1 deletion nequip/data/AtomicDataDict.py
@@ -5,6 +5,7 @@
Authors: Albert Musaelian
"""

from typing import Dict, Any

import torch
@@ -67,7 +68,10 @@ def with_edge_vectors(data: Type, with_lengths: bool = True) -> Type:
# (2) works on a Batch constructed from AtomicData
pos = data[_keys.POSITIONS_KEY]
edge_index = data[_keys.EDGE_INDEX_KEY]
edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
# edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
edge_vec = torch.index_select(pos, 0, edge_index[1]) - torch.index_select(
pos, 0, edge_index[0]
)
if _keys.CELL_KEY in data:
# ^ note that to save time we don't check that the edge_cell_shifts are trivial if no cell is provided; we just assume they are either not present or all zero.
# -1 gives a batch dim no matter what
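The `torch.index_select` form above is numerically identical to the commented-out advanced-indexing form; a quick self-contained check (the motivation for the switch is not stated in the diff, so none is claimed here):

```python
import torch

pos = torch.randn(5, 3)
edge_index = torch.tensor([[0, 1, 2, 4], [1, 2, 3, 0]])

a = pos[edge_index[1]] - pos[edge_index[0]]
b = torch.index_select(pos, 0, edge_index[1]) - torch.index_select(pos, 0, edge_index[0])
assert torch.equal(a, b)
```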
2 changes: 2 additions & 0 deletions nequip/data/__init__.py
@@ -8,6 +8,7 @@
_EDGE_FIELDS,
_GRAPH_FIELDS,
_LONG_FIELDS,
_CARTESIAN_TENSOR_FIELDS,
)
from ._dataset import (
AtomicDataset,
@@ -39,5 +40,6 @@
_EDGE_FIELDS,
_GRAPH_FIELDS,
_LONG_FIELDS,
_CARTESIAN_TENSOR_FIELDS,
EMTTestDataset,
]