Skip to content

Commit

Permalink
merged
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Oct 25, 2024
2 parents 1eedc8c + d3e0160 commit 817e21b
Show file tree
Hide file tree
Showing 16 changed files with 64 additions and 62 deletions.
4 changes: 4 additions & 0 deletions .github/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ changelog:
labels: [refactor]
- title: 🧪 Tests
labels: [tests]
- title: 🧹 Linting
labels: [linting]
- title: 🏷️ Static Typing
labels: [types] # as in static typing
- title: 💥 Breaking Changes
labels: [breaking]
- title: 🔒 Security Fixes
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-14, windows-latest]
python-version: ["39", "310", "311", "312"]
python-version: ["310", "311", "312"]
runs-on: ${{ matrix.os }}
steps:
- name: Check out repo
Expand Down
3 changes: 0 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ jobs:
pip install uv
uv pip install -e .[test,logging] --resolution=${{ matrix.version.resolution }} --system
# TODO: remove pin once reverse readline fixed
uv pip install monty==2024.7.12 --system
- name: Run Tests
run: pytest --capture=no --cov --cov-report=xml
env:
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
rev: v0.6.9
hooks:
- id: ruff
args: [--fix]
Expand All @@ -13,7 +13,7 @@ repos:
types_or: [python, jupyter]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-case-conflict
- id: check-symlinks
Expand Down Expand Up @@ -48,7 +48,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.10.0
rev: v9.12.0
hooks:
- id: eslint
types: [file]
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
![GitHub repo size](https://img.shields.io/github/repo-size/CederGroupHub/chgnet?logo=github&logoColor=white&label=Repo%20Size)
[![PyPI](https://img.shields.io/pypi/v/chgnet?logo=pypi&logoColor=white)](https://pypi.org/project/chgnet?logo=pypi&logoColor=white)
[![Docs](https://img.shields.io/badge/API-Docs-blue?logo=readthedocs&logoColor=white)](https://chgnet.lbl.gov)
[![Requires Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)
[![Requires Python 3.10+](https://img.shields.io/badge/Python-3.10+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)

</h4>

Expand Down
10 changes: 4 additions & 6 deletions chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ def _create_graph_legacy(
Graph data structure used to create Crystal_Graph object
"""
graph = Graph([Node(index=idx) for idx in range(n_atoms)])
for ii, jj, img, dist in zip(center_index, neighbor_index, image, distance):
for ii, jj, img, dist in zip(
center_index, neighbor_index, image, distance, strict=True
):
graph.add_edge(center_index=ii, neighbor_index=jj, image=img, distance=dist)

return graph
Expand Down Expand Up @@ -271,11 +273,7 @@ def set_isolated_atom_response(
"""Set the graph converter's response to isolated atom graph
Args:
on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
with isolated atoms.
Default = 'error'.
Returns:
None
with isolated atoms. Default = 'error'.
"""
self.on_isolated_atoms = on_isolated_atoms

Expand Down
2 changes: 1 addition & 1 deletion chgnet/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def line_graph_adjacency_list(self, cutoff) -> tuple[list[list[int]], list[int]]
# We will need to find directed edges with center = center1
# and create angles with DE1, then do the same for center2 and DE2
for center, dir_edge in zip(
u_edge.nodes, u_edge.info["directed_edge_index"]
u_edge.nodes, u_edge.info["directed_edge_index"], strict=True
):
for directed_edges in self.nodes[center].neighbors.values():
for directed_edge in directed_edges:
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/composition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def fit(
composition_feas = torch.zeros([num_data, self.max_num_elements])
e = torch.zeros([num_data])
for index, (structure, energy) in enumerate(
zip(structures_or_graphs, energies)
zip(structures_or_graphs, energies, strict=True)
):
if isinstance(structure, Structure):
atomic_number = torch.tensor(
Expand Down
14 changes: 7 additions & 7 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def __init__(
"""
self.ensemble = ensemble
self.thermostat = thermostat
if isinstance(atoms, (Structure, Molecule)):
if isinstance(atoms, Structure | Molecule):
atoms = AseAtomsAdaptor().get_atoms(atoms)
# atoms = atoms.to_ase_atoms()

Expand Down Expand Up @@ -825,9 +825,6 @@ def fit(
verbose (bool): Whether to print the output of the ASE optimizer.
Default = False
**kwargs: Additional parameters for the optimizer.
Returns:
Bulk Modulus (float)
"""
if isinstance(atoms, Atoms):
atoms = AseAtomsAdaptor.get_structure(atoms)
Expand Down Expand Up @@ -859,15 +856,18 @@ def fit(
self.bm.fit()
self.fitted = True

def get_bulk_modulus(self, unit: str = "eV/A^3") -> float:
def get_bulk_modulus(self, unit: Literal["eV/A^3", "GPa"] = "eV/A^3") -> float:
"""Get the bulk modulus of from the fitted Birch-Murnaghan equation of state.
Args:
unit (str): The unit of bulk modulus. Can be "eV/A^3" or "GPa"
Default = "eV/A^3"
Returns:
Bulk Modulus (float)
float: Bulk Modulus
Raises:
ValueError: If the equation of state is not fitted.
"""
if self.fitted is False:
raise ValueError(
Expand All @@ -877,7 +877,7 @@ def get_bulk_modulus(self, unit: str = "eV/A^3") -> float:
return self.bm.b0
if unit == "GPa":
return self.bm.b0_GPa
raise NotImplementedError("unit has to be eV/A^3 or GPa")
raise ValueError("unit has to be eV/A^3 or GPa")

def get_compressibility(self, unit: str = "A^3/eV") -> float:
"""Get the bulk modulus of from the fitted Birch-Murnaghan equation of state.
Expand Down
3 changes: 2 additions & 1 deletion chgnet/model/functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
from collections.abc import Sequence

import torch
Expand Down Expand Up @@ -83,7 +84,7 @@ def __init__(
find_activation(activation),
]
if len(hidden_dim) != 1:
for h_in, h_out in zip(hidden_dim[0:-1], hidden_dim[1:]):
for h_in, h_out in itertools.pairwise(hidden_dim):
layers.append(nn.Linear(h_in, h_out, bias=bias))
layers.append(find_activation(activation))
layers.append(nn.Dropout(dropout))
Expand Down
16 changes: 12 additions & 4 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,10 @@ def forward(
if return_site_energies and self.composition_model is not None:
site_energy_shifts = self.composition_model.get_site_energies(graphs)
prediction["site_energies"] = [
i + j for i, j in zip(prediction["site_energies"], site_energy_shifts)
i + j
for i, j in zip(
prediction["site_energies"], site_energy_shifts, strict=True
)
]
return prediction

Expand Down Expand Up @@ -437,7 +440,12 @@ def _compute(

# Message Passing
for idx, (atom_layer, bond_layer, angle_layer) in enumerate(
zip(self.atom_conv_layers[:-1], self.bond_conv_layers, self.angle_layers)
zip(
self.atom_conv_layers[:-1],
self.bond_conv_layers,
self.angle_layers,
strict=False,
)
):
# Atom Conv
atom_feas = atom_layer(
Expand Down Expand Up @@ -522,7 +530,7 @@ def _compute(
)
# Convert Stress unit from eV/A^3 to GPa
scale = 1 / g.volumes * 160.21766208
stress = [i * j for i, j in zip(stress, scale)]
stress = [i * j for i, j in zip(stress, scale, strict=False)]
prediction["s"] = stress

# Normalize energy if model is intensive
Expand Down Expand Up @@ -614,7 +622,7 @@ def predict_graph(
m (Tensor) : magnetic moments of sites [num_atoms, 3] in Bohr
magneton mu_B
"""
if not isinstance(graph, (CrystalGraph, Sequence)):
if not isinstance(graph, CrystalGraph | Sequence):
raise TypeError(
f"{type(graph)=} must be CrystalGraph or list of CrystalGraphs"
)
Expand Down
2 changes: 1 addition & 1 deletion chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ def forward(
if "m" in self.target_str:
mag_preds, mag_targets = [], []
m_mae_size = 0
for mag_pred, mag_target in zip(prediction["m"], targets["m"]):
for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True):
# exclude structures without magmom labels
if mag_target is not None:
mag_preds.append(mag_pred)
Expand Down
42 changes: 19 additions & 23 deletions chgnet/utils/vasp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings
from typing import TYPE_CHECKING

from monty.io import reverse_readfile
from monty.io import zopen
from monty.os.path import zpath
from pymatgen.io.vasp.outputs import Oszicar, Vasprun

Expand Down Expand Up @@ -58,13 +58,11 @@ def parse_vasp_dir(
exception_on_bad_xml=False,
)

charge, mag_x, mag_y, mag_z, header, all_lines = [], [], [], [], [], []
charge, mag_x, mag_y, mag_z, header = [], [], [], [], []

for line in reverse_readfile(outcar_path):
clean = line.strip()
all_lines.append(clean)
with zopen(outcar_path, encoding="utf-8") as file:
all_lines = [line.strip() for line in file.readlines()]

all_lines.reverse()
# For single atom systems, VASP doesn't print a total line, so
# reverse parsing is very difficult
# for SOC calculations only
Expand All @@ -79,23 +77,21 @@ def parse_vasp_dir(
if clean.startswith("# of ion"):
header = re.split(r"\s{2,}", clean.strip())
header.pop(0)
else:
m = re.match(r"\s*(\d+)\s+(([\d\.\-]+)\s+)+", clean)
if m:
tokens = [float(token) for token in re.findall(r"[\d\.\-]+", clean)]
tokens.pop(0)
if read_charge:
charge.append(dict(zip(header, tokens)))
elif read_mag_x:
mag_x.append(dict(zip(header, tokens)))
elif read_mag_y:
mag_y.append(dict(zip(header, tokens)))
elif read_mag_z:
mag_z.append(dict(zip(header, tokens)))
elif clean.startswith("tot"):
if ion_step_count == (len(mag_x_all) + 1):
mag_x_all.append(mag_x)
read_charge = read_mag_x = read_mag_y = read_mag_z = False
elif re.match(r"\s*(\d+)\s+(([\d\.\-]+)\s+)+", clean):
tokens = [float(token) for token in re.findall(r"[\d\.\-]+", clean)]
tokens.pop(0)
if read_charge:
charge.append(dict(zip(header, tokens, strict=True)))
elif read_mag_x:
mag_x.append(dict(zip(header, tokens, strict=True)))
elif read_mag_y:
mag_y.append(dict(zip(header, tokens, strict=True)))
elif read_mag_z:
mag_z.append(dict(zip(header, tokens, strict=True)))
elif clean.startswith("tot"):
if ion_step_count == (len(mag_x_all) + 1):
mag_x_all.append(mag_x)
read_charge = read_mag_x = read_mag_y = read_mag_z = False
if clean == "total charge":
read_charge = True
read_mag_x = read_mag_y = read_mag_z = False
Expand Down
4 changes: 2 additions & 2 deletions examples/crystaltoolkit_relax_viewer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@
" coords = trajectory.atom_positions[step]\n",
" structure.lattice = lattice # update structure in place for efficiency\n",
" assert len(structure) == len(coords)\n",
" for site, coord in zip(structure, coords):\n",
" for site, coord in zip(structure, coords, strict=True):\n",
" site.coords = coord\n",
"\n",
" title = make_title(*structure.get_space_group_info())\n",
Expand Down Expand Up @@ -406,7 +406,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion examples/fine_tuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
" from chgnet.model import CHGNet\n",
"except ImportError:\n",
" # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)\n",
" !pip install chgnet."
" !pip install chgnet"
]
},
{
Expand Down
12 changes: 5 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
[project]
name = "chgnet"
version = "0.3.8"
version = "0.4.0"
description = "Pretrained Universal Neural Network Potential for Charge-informed Atomistic Modeling"
authors = [{ name = "Bowen Deng", email = "[email protected]" }]
requires-python = ">=3.9"
requires-python = ">=3.10"
readme = "README.md"
license = { text = "Modified BSD" }
dependencies = [
"ase>=3.23.0",
"cython>=3",
# "monty==2024.7.12", # TODO: restore once readline fixed
# "numpy>=1.26", # TODO: remove after test
"numpy>=2.0.0",
"numpy>=1.26",
"nvidia-ml-py3>=7.352.0",
"pymatgen>=2024.9.10",
"torch>=2.4.1",
Expand All @@ -21,7 +19,6 @@ classifiers = [
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
Expand Down Expand Up @@ -55,7 +52,8 @@ requires = ["Cython", "numpy>=2.0.0", "setuptools>=65", "wheel"]
build-backend = "setuptools.build_meta"

[tool.ruff]
target-version = "py39"
target-version = "py310"
output-format = "concise"

[tool.ruff.lint]
select = ["ALL"]
Expand Down

0 comments on commit 817e21b

Please sign in to comment.