Skip to content

Commit

Permalink
bump pre-commit hooks, tweak ruff config, delete debug print statements
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Jan 22, 2024
1 parent e2a2b82 commit 4383857
Show file tree
Hide file tree
Showing 12 changed files with 83 additions and 72 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.8
rev: v0.1.14
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -46,7 +46,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v8.56.0
rev: v9.0.0-alpha.1
hooks:
- id: eslint
types: [file]
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

[![Tests](https://github.com/CederGroupHub/chgnet/actions/workflows/test.yml/badge.svg)](https://github.com/CederGroupHub/chgnet/actions/workflows/test.yml)
[![Codacy Badge](https://app.codacy.com/project/badge/Coverage/e3bdcea0382a495d96408e4f84408e85)](https://app.codacy.com/gh/CederGroupHub/chgnet/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_coverage)
[![arXiv](https://img.shields.io/badge/arXiv-2302.14231-blue)](https://arxiv.org/abs/2302.14231)
![GitHub repo size](https://img.shields.io/github/repo-size/CederGroupHub/chgnet)
[![arXiv](https://img.shields.io/badge/arXiv-2302.14231-blue?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2302.14231)
![GitHub repo size](https://img.shields.io/github/repo-size/CederGroupHub/chgnet?logo=github&logoColor=white&label=Repo%20Size)
[![PyPI](https://img.shields.io/pypi/v/chgnet?logo=pypi&logoColor=white)](https://pypi.org/project/chgnet?logo=pypi&logoColor=white)
[![Docs](https://img.shields.io/badge/API-Docs-blue)](https://chgnet.lbl.gov)
[![Docs](https://img.shields.io/badge/API-Docs-blue?logo=readthedocs&logoColor=white)](https://chgnet.lbl.gov)
[![Requires Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)

</h4>
Expand Down
42 changes: 21 additions & 21 deletions examples/fine_tuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
" import chgnet\n",
"except ImportError:\n",
" # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)\n",
" !pip install chgnet\n"
" !pip install chgnet"
]
},
{
Expand All @@ -45,7 +45,7 @@
"# If the above line fails in Google Colab due to numpy version issue,\n",
"# please restart the runtime, and the problem will be solved\n",
"\n",
"chgnet = CHGNet.load()\n"
"chgnet = CHGNet.load()"
]
},
{
Expand Down Expand Up @@ -82,7 +82,7 @@
"\n",
"# ./my_vasp_calc_dir contains vasprun.xml OSZICAR etc.\n",
"dataset_dict = parse_vasp_dir(file_root=\"./my_vasp_calc_dir\")\n",
"print(list(dataset_dict))\n"
"print(list(dataset_dict))"
]
},
{
Expand Down Expand Up @@ -135,7 +135,7 @@
"converter = CrystalGraphConverter()\n",
"for idx, struct in enumerate(dataset_dict[\"structure\"]):\n",
" graph = converter(struct)\n",
" graph.save(fname=f\"{idx}.pt\")\n"
" graph.save(fname=f\"{idx}.pt\")"
]
},
{
Expand Down Expand Up @@ -205,7 +205,7 @@
" stresses.append(\n",
" pred[\"s\"] * -10 + np.random.uniform(-0.05, 0.05, size=pred[\"s\"].shape)\n",
" )\n",
" magmoms.append(pred[\"m\"] + np.random.uniform(-0.03, 0.03, size=pred[\"m\"].shape))\n"
" magmoms.append(pred[\"m\"] + np.random.uniform(-0.03, 0.03, size=pred[\"m\"].shape))"
]
},
{
Expand Down Expand Up @@ -233,7 +233,7 @@
"metadata": {},
"outputs": [],
"source": [
"from chgnet.data.dataset import StructureData, get_train_val_test_loader\n"
"from chgnet.data.dataset import StructureData, get_train_val_test_loader"
]
},
{
Expand All @@ -260,7 +260,7 @@
")\n",
"train_loader, val_loader, test_loader = get_train_val_test_loader(\n",
" dataset, batch_size=8, train_ratio=0.9, val_ratio=0.05\n",
")\n"
")"
]
},
{
Expand Down Expand Up @@ -301,7 +301,7 @@
"from chgnet.trainer import Trainer\n",
"\n",
"# Load pretrained CHGNet\n",
"chgnet = CHGNet.load()\n"
"chgnet = CHGNet.load()"
]
},
{
Expand Down Expand Up @@ -331,7 +331,7 @@
" chgnet.angle_layers,\n",
"]:\n",
" for param in layer.parameters():\n",
" param.requires_grad = False\n"
" param.requires_grad = False"
]
},
{
Expand All @@ -352,7 +352,7 @@
" learning_rate=1e-2,\n",
" use_device=\"cpu\",\n",
" print_freq=6,\n",
")\n"
")"
]
},
{
Expand Down Expand Up @@ -401,7 +401,7 @@
}
],
"source": [
"trainer.train(train_loader, val_loader, test_loader)\n"
"trainer.train(train_loader, val_loader, test_loader)"
]
},
{
Expand All @@ -420,7 +420,7 @@
"outputs": [],
"source": [
"model = trainer.model\n",
"best_model = trainer.best_model # best model based on validation energy MAE\n"
"best_model = trainer.best_model # best model based on validation energy MAE"
]
},
{
Expand Down Expand Up @@ -465,7 +465,7 @@
"# Imagine this is the VASP raw energy\n",
"vasp_raw_energy = -58.97\n",
"\n",
"print(f\"The raw total energy from VASP of LMO is: {vasp_raw_energy} eV\")\n"
"print(f\"The raw total energy from VASP of LMO is: {vasp_raw_energy} eV\")"
]
},
{
Expand Down Expand Up @@ -511,7 +511,7 @@
"corrected_energy = (\n",
" vasp_raw_energy + num_Mn * Mn_correction_in_TMO + num_O * oxide_correction\n",
")\n",
"print(f\"The corrected total energy after MP2020 = {corrected_energy:.4} eV\")\n"
"print(f\"The corrected total energy after MP2020 = {corrected_energy:.4} eV\")"
]
},
{
Expand Down Expand Up @@ -547,7 +547,7 @@
"MaterialsProject2020Compatibility(check_potcar=False).process_entries(cse)\n",
"print(\n",
" f\"The total energy of LMO after MP2020Compatibility correction = {cse.energy:.4} eV\"\n",
")\n"
")"
]
},
{
Expand Down Expand Up @@ -626,7 +626,7 @@
}
],
"source": [
"trainer.train(train_loader, val_loader, test_loader, train_composition_model=True)\n"
"trainer.train(train_loader, val_loader, test_loader, train_composition_model=True)"
]
},
{
Expand Down Expand Up @@ -669,7 +669,7 @@
"source": [
"print(\"The pretrained Atom_Ref (per atom reference energy):\")\n",
"for param in chgnet.composition_model.parameters():\n",
" print(param)\n"
" print(param)"
]
},
{
Expand Down Expand Up @@ -700,7 +700,7 @@
"]\n",
"\n",
"# A list of energy_per_atom values (random values here)\n",
"energies_per_atom = [5.5, 6, 4.8, 5.6]\n"
"energies_per_atom = [5.5, 6, 4.8, 5.6]"
]
},
{
Expand All @@ -725,7 +725,7 @@
"new_atom_ref = AtomRef(is_intensive=True)\n",
"new_atom_ref.initialize_from_MPtrj()\n",
"for param in new_atom_ref.parameters():\n",
" print(param[:, :3])\n"
" print(param[:, :3])"
]
},
{
Expand Down Expand Up @@ -769,7 +769,7 @@
"new_atom_ref.fit(structures, energies_per_atom)\n",
"print(\"After refitting, the AtomRef looks like:\")\n",
"for param in new_atom_ref.parameters():\n",
" print(param)\n"
" print(param)"
]
}
],
Expand All @@ -789,7 +789,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions examples/make_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ def make_graphs(

def make_one_graph(mp_id: str, graph_id: str, data, graph_dir) -> dict | bool:
"""Convert a structure to a CrystalGraph and save it."""
dic = data.data[mp_id].pop(graph_id)
struct = Structure.from_dict(dic.pop("structure"))
dct = data.data[mp_id].pop(graph_id)
struct = Structure.from_dict(dct.pop("structure"))
try:
graph = data.graph_converter(struct, graph_id=graph_id, mp_id=mp_id)
torch.save(graph, os.path.join(graph_dir, f"{graph_id}.pt"))
return dic
return dct
except Exception:
return False

Expand Down
14 changes: 12 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ select = [
"RUF", # Ruff-specific rules
"SIM", # flake8-simplify
"SLOT", # flakes8-slot
"T201",
"TCH", # flake8-type-checking
"TID", # tidy imports
"TID", # flake8-tidy-imports
Expand All @@ -82,25 +83,34 @@ select = [
"YTT", # flake8-2020
]
ignore = [
"ANN001", # TODO add missing type annotations
"ANN101", # Missing type annotation for self in method
"B019", # Use of functools.lru_cache on methods can lead to memory leaks
"C408", # unnecessary-collection-call
"COM812", # trailing comma missing
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D205", # 1 blank line required between summary line and description
"FBT001", # Boolean positional argument in function
"FBT002", # Boolean keyword argument in function
"NPY002", # TODO replace legacy np.random.seed
"PLR", # pylint refactor
"PLW2901", # Outer for loop variable overwritten by inner assignment target
"PT006", # pytest-parametrize-names-wrong-type
"PT011", # pytest-raises-too-broad
"PT013", # pytest-incorrect-pytest-import
"PT019", # pytest-fixture-param-without-value
"PTH", # prefer Path to os.path
]
pydocstyle.convention = "google"
isort.required-imports = ["from __future__ import annotations"]
isort.split-on-trailing-comma = false

[tool.ruff.per-file-ignores]
"tests/*" = ["D103"]
"examples/*" = ["E402", "I002"] # E402 Module level import not at top of file
"tests/*" = ["ANN201", "D103", "INP001", "S101"]
# E402 Module level import not at top of file
"examples/*" = ["E402", "I002", "T201"]
"chgnet/**/*" = ["T201"]
"__init__.py" = ["F401"]


Expand Down
4 changes: 2 additions & 2 deletions site/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
"sourceMap": true,

"forceConsistentCasingInFileNames": true,
"resolveJsonModule": true
}
"resolveJsonModule": true,
},
}
12 changes: 8 additions & 4 deletions tests/test_converter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Literal

import pytest
from pymatgen.core import Lattice, Structure
from pytest import CaptureFixture
Expand All @@ -14,7 +16,7 @@


@pytest.fixture()
def _set_make_graph():
def _set_make_graph() -> None:
# fixture to force make_graph to be None and then restore it after test
from chgnet.graph import converter

Expand All @@ -27,7 +29,9 @@ def _set_make_graph():
@pytest.mark.parametrize(
"atom_graph_cutoff, bond_graph_cutoff", [(5, 3), (5, None), (4, 2)]
)
def test_crystal_graph_converter_cutoff(atom_graph_cutoff, bond_graph_cutoff):
def test_crystal_graph_converter_cutoff(
atom_graph_cutoff: float | None, bond_graph_cutoff: float | None
):
converter = CrystalGraphConverter(
atom_graph_cutoff=atom_graph_cutoff, bond_graph_cutoff=bond_graph_cutoff
)
Expand All @@ -36,14 +40,14 @@ def test_crystal_graph_converter_cutoff(atom_graph_cutoff, bond_graph_cutoff):


@pytest.mark.parametrize("algorithm", ["legacy", "fast"])
def test_crystal_graph_converter_algorithm(algorithm):
def test_crystal_graph_converter_algorithm(algorithm: Literal["legacy", "fast"]):
converter = CrystalGraphConverter(
atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm=algorithm
)
assert converter.algorithm == algorithm


def test_crystal_graph_converter_warns(_set_make_graph):
def test_crystal_graph_converter_warns(_set_make_graph: None):
with pytest.warns(UserWarning, match="Unknown algorithm='foobar', using `legacy`"):
CrystalGraphConverter(
atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm="foobar"
Expand Down
Loading

0 comments on commit 4383857

Please sign in to comment.