From 4383857972806ace6a162ae1dc92c70f982d3a74 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 22 Jan 2024 17:42:59 +0100 Subject: [PATCH] bump pre-commit hooks, tweak ruff config, delete debug print statements --- .pre-commit-config.yaml | 4 ++-- README.md | 6 +++--- examples/fine_tuning.ipynb | 42 ++++++++++++++++++------------------- examples/make_graphs.py | 6 +++--- pyproject.toml | 14 +++++++++++-- site/tsconfig.json | 4 ++-- tests/test_converter.py | 12 +++++++---- tests/test_crystal_graph.py | 28 ++++++++++++------------- tests/test_encoders.py | 4 ++-- tests/test_md.py | 3 --- tests/test_model.py | 23 ++++++++------------ tests/test_trainer.py | 9 ++++++-- 12 files changed, 83 insertions(+), 72 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 08f4683d..ad6d3fec 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.8 + rev: v0.1.14 hooks: - id: ruff args: [--fix] @@ -46,7 +46,7 @@ repos: - svelte - repo: https://github.com/pre-commit/mirrors-eslint - rev: v8.56.0 + rev: v9.0.0-alpha.1 hooks: - id: eslint types: [file] diff --git a/README.md b/README.md index 2b9790ae..529c7109 100755 --- a/README.md +++ b/README.md @@ -4,10 +4,10 @@ [![Tests](https://github.com/CederGroupHub/chgnet/actions/workflows/test.yml/badge.svg)](https://github.com/CederGroupHub/chgnet/actions/workflows/test.yml) [![Codacy Badge](https://app.codacy.com/project/badge/Coverage/e3bdcea0382a495d96408e4f84408e85)](https://app.codacy.com/gh/CederGroupHub/chgnet/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_coverage) -[![arXiv](https://img.shields.io/badge/arXiv-2302.14231-blue)](https://arxiv.org/abs/2302.14231) -![GitHub repo size](https://img.shields.io/github/repo-size/CederGroupHub/chgnet) +[![arXiv](https://img.shields.io/badge/arXiv-2302.14231-blue?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2302.14231) +![GitHub repo size](https://img.shields.io/github/repo-size/CederGroupHub/chgnet?logo=github&logoColor=white&label=Repo%20Size) [![PyPI](https://img.shields.io/pypi/v/chgnet?logo=pypi&logoColor=white)](https://pypi.org/project/chgnet?logo=pypi&logoColor=white) -[![Docs](https://img.shields.io/badge/API-Docs-blue)](https://chgnet.lbl.gov) +[![Docs](https://img.shields.io/badge/API-Docs-blue?logo=readthedocs&logoColor=white)](https://chgnet.lbl.gov) [![Requires Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads) diff --git a/examples/fine_tuning.ipynb b/examples/fine_tuning.ipynb index 54e6754f..b844231f 100644 --- a/examples/fine_tuning.ipynb +++ b/examples/fine_tuning.ipynb @@ -19,7 +19,7 @@ " import chgnet\n", "except ImportError:\n", " # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)\n", - " !pip install chgnet\n" + " !pip install chgnet" ] }, { @@ -45,7 +45,7 @@ "# If the above line fails in Google Colab due to numpy version issue,\n", "# please restart the runtime, and the problem will be solved\n", "\n", - "chgnet = CHGNet.load()\n" + "chgnet = CHGNet.load()" ] }, { @@ -82,7 +82,7 @@ "\n", "# ./my_vasp_calc_dir contains vasprun.xml OSZICAR etc.\n", "dataset_dict = parse_vasp_dir(file_root=\"./my_vasp_calc_dir\")\n", - "print(list(dataset_dict))\n" + "print(list(dataset_dict))" ] }, { @@ -135,7 +135,7 @@ "converter = CrystalGraphConverter()\n", "for idx, struct in enumerate(dataset_dict[\"structure\"]):\n", " graph = converter(struct)\n", - " graph.save(fname=f\"{idx}.pt\")\n" + " graph.save(fname=f\"{idx}.pt\")" ] }, { @@ -205,7 +205,7 @@ " stresses.append(\n", " pred[\"s\"] * -10 + np.random.uniform(-0.05, 0.05, size=pred[\"s\"].shape)\n", " )\n", - " magmoms.append(pred[\"m\"] + np.random.uniform(-0.03, 0.03, size=pred[\"m\"].shape))\n" + " magmoms.append(pred[\"m\"] + np.random.uniform(-0.03, 0.03, size=pred[\"m\"].shape))" ] }, { @@ -233,7 +233,7 @@ "metadata": {}, "outputs": [], "source": [ - "from chgnet.data.dataset import StructureData, get_train_val_test_loader\n" + "from chgnet.data.dataset import StructureData, get_train_val_test_loader" ] }, { @@ -260,7 +260,7 @@ ")\n", "train_loader, val_loader, test_loader = get_train_val_test_loader(\n", " dataset, batch_size=8, train_ratio=0.9, val_ratio=0.05\n", - ")\n" + ")" ] }, { @@ -301,7 +301,7 @@ "from chgnet.trainer import Trainer\n", "\n", "# Load pretrained CHGNet\n", - "chgnet = CHGNet.load()\n" + "chgnet = CHGNet.load()" ] }, { @@ -331,7 +331,7 @@ " chgnet.angle_layers,\n", "]:\n", " for param in layer.parameters():\n", - " param.requires_grad = False\n" + " param.requires_grad = False" ] }, { @@ -352,7 +352,7 @@ " learning_rate=1e-2,\n", " use_device=\"cpu\",\n", " print_freq=6,\n", - ")\n" + ")" ] }, { @@ -401,7 +401,7 @@ } ], "source": [ - "trainer.train(train_loader, val_loader, test_loader)\n" + "trainer.train(train_loader, val_loader, test_loader)" ] }, { @@ -420,7 +420,7 @@ "outputs": [], "source": [ "model = trainer.model\n", - "best_model = trainer.best_model # best model based on validation energy MAE\n" + "best_model = trainer.best_model # best model based on validation energy MAE" ] }, { @@ -465,7 +465,7 @@ "# Imagine this is the VASP raw energy\n", "vasp_raw_energy = -58.97\n", "\n", - "print(f\"The raw total energy from VASP of LMO is: {vasp_raw_energy} eV\")\n" + "print(f\"The raw total energy from VASP of LMO is: {vasp_raw_energy} eV\")" ] }, { @@ -511,7 +511,7 @@ "corrected_energy = (\n", " vasp_raw_energy + num_Mn * Mn_correction_in_TMO + num_O * oxide_correction\n", ")\n", - "print(f\"The corrected total energy after MP2020 = {corrected_energy:.4} eV\")\n" + "print(f\"The corrected total energy after MP2020 = {corrected_energy:.4} eV\")" ] }, { @@ -547,7 +547,7 @@ "MaterialsProject2020Compatibility(check_potcar=False).process_entries(cse)\n", "print(\n", " f\"The total energy of LMO after MP2020Compatibility correction = {cse.energy:.4} eV\"\n", - ")\n" + ")" ] }, { @@ -626,7 +626,7 @@ } ], "source": [ - "trainer.train(train_loader, val_loader, test_loader, train_composition_model=True)\n" + "trainer.train(train_loader, val_loader, test_loader, train_composition_model=True)" ] }, { @@ -669,7 +669,7 @@ "source": [ "print(\"The pretrained Atom_Ref (per atom reference energy):\")\n", "for param in chgnet.composition_model.parameters():\n", - " print(param)\n" + " print(param)" ] }, { @@ -700,7 +700,7 @@ "]\n", "\n", "# A list of energy_per_atom values (random values here)\n", - "energies_per_atom = [5.5, 6, 4.8, 5.6]\n" + "energies_per_atom = [5.5, 6, 4.8, 5.6]" ] }, { @@ -725,7 +725,7 @@ "new_atom_ref = AtomRef(is_intensive=True)\n", "new_atom_ref.initialize_from_MPtrj()\n", "for param in new_atom_ref.parameters():\n", - " print(param[:, :3])\n" + " print(param[:, :3])" ] }, { @@ -769,7 +769,7 @@ "new_atom_ref.fit(structures, energies_per_atom)\n", "print(\"After refitting, the AtomRef looks like:\")\n", "for param in new_atom_ref.parameters():\n", - " print(param)\n" + " print(param)" ] } ], @@ -789,7 +789,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/examples/make_graphs.py b/examples/make_graphs.py index 3fc7996b..945d20b9 100644 --- a/examples/make_graphs.py +++ b/examples/make_graphs.py @@ -59,12 +59,12 @@ def make_graphs( def make_one_graph(mp_id: str, graph_id: str, data, graph_dir) -> dict | bool: """Convert a structure to a CrystalGraph and save it.""" - dic = data.data[mp_id].pop(graph_id) - struct = Structure.from_dict(dic.pop("structure")) + dct = data.data[mp_id].pop(graph_id) + struct = Structure.from_dict(dct.pop("structure")) try: graph = data.graph_converter(struct, graph_id=graph_id, mp_id=mp_id) torch.save(graph, os.path.join(graph_dir, f"{graph_id}.pt")) - return dic + return dct except Exception: return False diff --git a/pyproject.toml b/pyproject.toml index 7e8c7c4d..d0ecbaa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ select = [ "RUF", # Ruff-specific rules "SIM", # flake8-simplify "SLOT", # flakes8-slot + "T201", "TCH", # flake8-type-checking "TID", # tidy imports "TID", # flake8-tidy-imports @@ -82,25 +83,34 @@ select = [ "YTT", # flake8-2020 ] ignore = [ + "ANN001", # TODO add missing type annotations + "ANN101", # Missing type annotation for self in method "B019", # Use of functools.lru_cache on methods can lead to memory leaks "C408", # unnecessary-collection-call + "COM812", # trailing comma missing "D100", # Missing docstring in public module "D104", # Missing docstring in public package "D205", # 1 blank line required between summary line and description + "FBT001", # Boolean positional argument in function + "FBT002", # Boolean keyword argument in function + "NPY002", # TODO replace legacy np.random.seed "PLR", # pylint refactor "PLW2901", # Outer for loop variable overwritten by inner assignment target "PT006", # pytest-parametrize-names-wrong-type "PT011", # pytest-raises-too-broad "PT013", # pytest-incorrect-pytest-import "PT019", # pytest-fixture-param-without-value + "PTH", # prefer Path to os.path ] pydocstyle.convention = "google" isort.required-imports = ["from __future__ import annotations"] isort.split-on-trailing-comma = false [tool.ruff.per-file-ignores] -"tests/*" = ["D103"] -"examples/*" = ["E402", "I002"] # E402 Module level import not at top of file +"tests/*" = ["ANN201", "D103", "INP001", "S101"] +# E402 Module level import not at top of file +"examples/*" = ["E402", "I002", "T201"] +"chgnet/**/*" = ["T201"] "__init__.py" = ["F401"] diff --git a/site/tsconfig.json b/site/tsconfig.json index 6b5192c8..c10cdc96 100644 --- a/site/tsconfig.json +++ b/site/tsconfig.json @@ -11,6 +11,6 @@ "sourceMap": true, "forceConsistentCasingInFileNames": true, - "resolveJsonModule": true - } + "resolveJsonModule": true, + }, } diff --git a/tests/test_converter.py b/tests/test_converter.py index ecb9eaeb..eec38b18 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Literal + import pytest from pymatgen.core import Lattice, Structure from pytest import CaptureFixture @@ -14,7 +16,7 @@ @pytest.fixture() -def _set_make_graph(): +def _set_make_graph() -> None: # fixture to force make_graph to be None and then restore it after test from chgnet.graph import converter @@ -27,7 +29,9 @@ def _set_make_graph(): @pytest.mark.parametrize( "atom_graph_cutoff, bond_graph_cutoff", [(5, 3), (5, None), (4, 2)] ) -def test_crystal_graph_converter_cutoff(atom_graph_cutoff, bond_graph_cutoff): +def test_crystal_graph_converter_cutoff( + atom_graph_cutoff: float | None, bond_graph_cutoff: float | None +): converter = CrystalGraphConverter( atom_graph_cutoff=atom_graph_cutoff, bond_graph_cutoff=bond_graph_cutoff ) @@ -36,14 +40,14 @@ def test_crystal_graph_converter_cutoff(atom_graph_cutoff, bond_graph_cutoff): @pytest.mark.parametrize("algorithm", ["legacy", "fast"]) -def test_crystal_graph_converter_algorithm(algorithm): +def test_crystal_graph_converter_algorithm(algorithm: Literal["legacy", "fast"]): converter = CrystalGraphConverter( atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm=algorithm ) assert converter.algorithm == algorithm -def test_crystal_graph_converter_warns(_set_make_graph): +def test_crystal_graph_converter_warns(_set_make_graph: None): with pytest.warns(UserWarning, match="Unknown algorithm='foobar', using `legacy`"): CrystalGraphConverter( atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm="foobar" diff --git a/tests/test_crystal_graph.py b/tests/test_crystal_graph.py index d24d139a..8834d235 100644 --- a/tests/test_crystal_graph.py +++ b/tests/test_crystal_graph.py @@ -24,7 +24,7 @@ def test_crystal_graph_legacy(): assert converter_legacy.algorithm == "legacy" start = perf_counter() graph = converter_legacy(structure) - print("Legacy test_crystal_graph time:", perf_counter() - start) + print("Legacy test_crystal_graph time:", perf_counter() - start) # noqa: T201 assert graph.composition == "Li2 Mn2 O4" assert graph.atomic_number.tolist() == [3, 3, 25, 25, 8, 8, 8, 8] @@ -50,7 +50,7 @@ def test_crystal_graph_fast(): assert converter_fast.algorithm == "fast" start = perf_counter() graph = converter_fast(structure) - print("Fast test_crystal_graph time:", perf_counter() - start) + print("Fast test_crystal_graph time:", perf_counter() - start) # noqa: T201 assert graph.composition == "Li2 Mn2 O4" assert graph.atomic_number.tolist() == [3, 3, 25, 25, 8, 8, 8, 8] @@ -80,7 +80,7 @@ def test_crystal_graph_different_cutoff_legacy(): start = perf_counter() graph = converter_legacy_2(structure) - print("Legacy test_crystal_graph_different_cutoff time:", perf_counter() - start) + print("Legacy test_crystal_graph_different_cutoff time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [624, 2] @@ -107,7 +107,7 @@ def test_crystal_graph_different_cutoff_fast(): start = perf_counter() graph = converter_fast_2(structure) - print("Fast test_crystal_graph_different_cutoff time:", perf_counter() - start) + print("Fast test_crystal_graph_different_cutoff time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [624, 2] @@ -133,7 +133,7 @@ def test_crystal_graph_perturb_legacy(): start = perf_counter() graph = converter_legacy(structure_perturbed) - print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) + print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [410, 2] @@ -159,7 +159,7 @@ def test_crystal_graph_perturb_fast(): start = perf_counter() graph = converter_fast(structure_perturbed) - print("Fast test_crystal_graph_perturb time:", perf_counter() - start) + print("Fast test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [410, 2] @@ -184,7 +184,7 @@ def test_crystal_graph_isotropic_strained_legacy(): start = perf_counter() graph = converter_legacy(structure_strained) - print("Legacy test_crystal_graph_isotropic_strained time:", perf_counter() - start) + print("Legacy test_crystal_graph_isotropic_strained time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [264, 2] @@ -204,7 +204,7 @@ def test_crystal_graph_isotropic_strained_fast(): start = perf_counter() graph = converter_fast(structure_strained) - print("Fast test_crystal_graph_isotropic_strained time:", perf_counter() - start) + print("Fast test_crystal_graph_isotropic_strained time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [264, 2] @@ -224,7 +224,7 @@ def test_crystal_graph_anisotropic_strained_legacy(): start = perf_counter() graph = converter_legacy(structure_strained) - print( + print( # noqa: T201 "Legacy test_crystal_graph_anisotropic_strained time:", perf_counter() - start ) @@ -246,7 +246,7 @@ def test_crystal_graph_anisotropic_strained_fast(): start = perf_counter() graph = converter_fast(structure_strained) - print("Fast test_crystal_graph_anisotropic_strained time:", perf_counter() - start) + print("Fast test_crystal_graph_anisotropic_strained time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [336, 2] @@ -265,7 +265,7 @@ def test_crystal_graph_supercell_legacy(): start = perf_counter() graph = converter_legacy(supercell) - print("Legacy test_crystal_graph_supercell time:", perf_counter() - start) + print("Legacy test_crystal_graph_supercell time:", perf_counter() - start) # noqa: T201 assert graph.composition == "Li48 Mn48 O96" assert list(graph.atom_frac_coord.shape) == [192, 3] @@ -290,7 +290,7 @@ def test_crystal_graph_supercell_fast(): start = perf_counter() graph = converter_fast(supercell) - print("Fast test_crystal_graph_supercell time:", perf_counter() - start) + print("Fast test_crystal_graph_supercell time:", perf_counter() - start) # noqa: T201# noqa: T201 assert graph.composition == "Li48 Mn48 O96" assert list(graph.atom_frac_coord.shape) == [192, 3] @@ -323,7 +323,7 @@ def test_crystal_graph_stability_legacy(): graph.directed2undirected.shape[0] == 2 * graph.undirected2directed.shape[0] ) assert graph.atom_graph.shape[0] == graph.directed2undirected.shape[0] - print("Legacy test_crystal_graph_stability time:", total_time) + print("Legacy test_crystal_graph_stability time:", total_time) # noqa: T201 def test_crystal_graph_stability_fast(): @@ -339,7 +339,7 @@ def test_crystal_graph_stability_fast(): graph.directed2undirected.shape[0] == 2 * graph.undirected2directed.shape[0] ) assert graph.atom_graph.shape[0] == graph.directed2undirected.shape[0] - print("Fast test_crystal_graph_stability time:", total_time) + print("Fast test_crystal_graph_stability time:", total_time) # noqa: T201 def test_crystal_graph_repr(): diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 8c7dd0fb..ec66d8bb 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize("atom_feature_dim", [16, 32, 64]) @pytest.mark.parametrize("max_num_elements", [94, 89]) -def test_atom_embedding(atom_feature_dim: int, max_num_elements) -> None: +def test_atom_embedding(atom_feature_dim: int, max_num_elements: int) -> None: atom_embedding = AtomEmbedding(atom_feature_dim, max_num_elements=max_num_elements) atomic_numbers = torch.tensor([6, 7, 8]) @@ -29,7 +29,7 @@ def test_atom_embedding(atom_feature_dim: int, max_num_elements) -> None: @pytest.mark.parametrize("atom_graph_cutoff, bond_graph_cutoff", [(5, 3), (6, 4)]) -def test_bond_encoder(atom_graph_cutoff, bond_graph_cutoff) -> None: +def test_bond_encoder(atom_graph_cutoff: float, bond_graph_cutoff: float) -> None: undirected2directed = torch.tensor([0, 1]) image = torch.zeros((2, 3)) lattice = torch.eye(3) diff --git a/tests/test_md.py b/tests/test_md.py index 74cd4743..7bf22de6 100644 --- a/tests/test_md.py +++ b/tests/test_md.py @@ -123,7 +123,6 @@ def test_md_nve(tmp_path: Path, monkeypatch: MonkeyPatch): assert set(os.listdir()) == {"md_out.log", "md_out.traj"} with open("md_out.log") as log_file: logs = log_file.read() - print("nve logs", logs) assert logs == ( "Time[ps] Etot[eV] Epot[eV] Ekin[eV] T[K]\n" "0.0000 -58.9415 -58.9415 0.0000 0.0\n" @@ -308,8 +307,6 @@ def test_md_crystal_feas_log( assert isinstance(crystal_feas, list) assert len(crystal_feas) == 101 assert len(crystal_feas[0]) == 64 - print(crystal_feas[0][0], crystal_feas[0][1]) - print(crystal_feas[10][0], crystal_feas[10][1]) assert crystal_feas[0][0] == approx(-0.002082636, abs=1e-5) assert crystal_feas[0][1] == approx(-1.4285042, abs=1e-5) assert crystal_feas[10][0] == approx(-0.0020592688, abs=1e-5) diff --git a/tests/test_model.py b/tests/test_model.py index 5c251104..98555274 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -63,15 +63,7 @@ def test_predict_structure() -> None: return_atom_feas=True, return_crystal_feas=True, ) - assert sorted(out) == [ - "atom_fea", - "crystal_fea", - "e", - "f", - "m", - "s", - "site_energies", - ] + assert sorted(out) == ["atom_fea", "crystal_fea", *"efms", "site_energies"] assert out["e"] == pytest.approx(-7.36769, rel=1e-4, abs=1e-4) forces = [ @@ -91,7 +83,6 @@ def test_predict_structure() -> None: [-1.2128221e-06, 2.2305478e-01, -3.2104114e-07], [1.3322200e-06, -8.3219516e-07, -1.0736181e-01], ] - print("stress", stress) assert out["s"] == pytest.approx(np.array(stress), rel=5e-3, abs=1e-4) magmom = [ @@ -154,15 +145,19 @@ def test_predict_structure_rotated(rotation_angle: float, axis: list) -> None: a, b, c = axis_normalized # Compute the skew-symmetric matrix K - K = np.array([[0, -c, b], [c, 0, -a], [-b, a, 0]]) + skew_mat = np.array([[0, -c, b], [c, 0, -a], [-b, a, 0]]) # Compute the rotation matrix using Rodrigues' formula - R = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * np.dot(K, K) + rot_mat = ( + np.eye(3) + + np.sin(theta) * skew_mat + + (1 - np.cos(theta)) * np.dot(skew_mat, skew_mat) + ) - rotated_force = pristine_prediction["f"] @ R.transpose() + rotated_force = pristine_prediction["f"] @ rot_mat.transpose() assert out["f"] == pytest.approx(rotated_force, rel=1e-3, abs=1e-3) - rotated_stress = R @ pristine_prediction["s"] @ R.transpose() + rotated_stress = rot_mat @ pristine_prediction["s"] @ rot_mat.transpose() assert out["s"] == pytest.approx(rotated_stress, rel=1e-3, abs=1e-3) assert out["m"] == pytest.approx(pristine_prediction["m"], rel=1e-4, abs=1e-4) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index bd93c3c0..5072489a 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np import torch from pymatgen.core import Lattice, Structure @@ -8,6 +10,9 @@ from chgnet.model import CHGNet from chgnet.trainer import Trainer +if TYPE_CHECKING: + from pathlib import Path + lattice = Lattice.cubic(4) species = ["Na", "Cl"] coords = [[0, 0, 0], [0.5, 0.5, 0.5]] @@ -31,7 +36,7 @@ ) -def test_trainer(tmp_path) -> None: +def test_trainer(tmp_path: Path) -> None: chgnet = CHGNet.load() train_loader, val_loader, test_loader = get_train_val_test_loader( data, batch_size=16, train_ratio=0.9, val_ratio=0.05 @@ -59,7 +64,7 @@ def test_trainer(tmp_path) -> None: ), f"Expected 1 {prefix} file, found {n_matches} in {output_files}" -def test_trainer_composition_model(tmp_path) -> None: +def test_trainer_composition_model(tmp_path: Path) -> None: chgnet = CHGNet.load() for param in chgnet.composition_model.parameters(): assert param.requires_grad is False