diff --git a/docs/source/01_quickstart/getting_started.rst b/docs/source/01_quickstart/getting_started.rst index 6799cf1ca..fcd5d506d 100644 --- a/docs/source/01_quickstart/getting_started.rst +++ b/docs/source/01_quickstart/getting_started.rst @@ -73,40 +73,6 @@ use corresponding getters :meth:`~dxtb.Calculator.get_energy`: We recommend using the getters, as they provide the familiar ASE-like interface. -.. warning:: - - If you supply the **same inputs** to the calculator multiple times with - gradient tracking enabled, you have to reset the calculator in between with - :meth:`~dxtb.Calculator.reset_all`. Otherwise, the gradients will be wrong. - - .. admonition:: Example - :class: toggle - - .. code-block:: python - - import torch - import dxtb - - dd = {"dtype": torch.double, "device": torch.device("cpu")} - - numbers = torch.tensor([3, 1], device=dd["device"]) - positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd) - - calc = dxtb.calculators.GFN1Calculator(numbers, **dd) - - pos = positions.clone().requires_grad_(True) - energy = calc.energy(pos) - (g1,) = torch.autograd.grad(energy, pos) - - # wrong gradients without reset here - calc.reset_all() - - pos = positions.clone().requires_grad_(True) - energy = calc.energy(pos) - (g2,) = torch.autograd.grad(energy, pos) - - assert torch.allclose(g1, g2) - Gradients --------- @@ -146,6 +112,41 @@ The equivalency of the two methods (except for the sign) can be verified by the example `here `_. +.. warning:: + + If you supply the **same inputs** to the calculator multiple times with + gradient tracking enabled, you have to reset the calculator in between with + :meth:`~dxtb.Calculator.reset_all`. Otherwise, the gradients will be wrong. + + .. admonition:: Example + :class: toggle + + .. code-block:: python + + import torch + import dxtb + + dd = {"dtype": torch.double, "device": torch.device("cpu")} + + numbers = torch.tensor([3, 1], device=dd["device"]) + positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd) + + calc = dxtb.calculators.GFN1Calculator(numbers, **dd) + + pos = positions.clone().requires_grad_(True) + energy = calc.energy(pos) + (g1,) = torch.autograd.grad(energy, pos) + + # wrong gradients without reset here + calc.reset_all() + + pos = positions.clone().requires_grad_(True) + energy = calc.energy(pos) + (g2,) = torch.autograd.grad(energy, pos) + + assert torch.allclose(g1, g2) + + More Properties --------------- diff --git a/docs/source/_static/dxtb-favicon.png b/docs/source/_static/dxtb-favicon.png new file mode 100644 index 000000000..8e4006daa Binary files /dev/null and b/docs/source/_static/dxtb-favicon.png differ diff --git a/docs/source/_static/dxtb.png b/docs/source/_static/dxtb.png new file mode 100644 index 000000000..706d74e17 Binary files /dev/null and b/docs/source/_static/dxtb.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py index 99d436df7..beccb6d8a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -47,8 +47,8 @@ html_theme = "sphinx_book_theme" html_title = project -html_logo = "_static/dxtb.svg" -html_favicon = "_static/dxtb-favicon.svg" +html_logo = "_static/dxtb.png" +html_favicon = "_static/dxtb-favicon.png" html_theme_options = { "navigation_with_keys": False, @@ -81,6 +81,7 @@ "python": ("https://docs.python.org/3", None), "tad_dftd3": ("https://tad-dftd3.readthedocs.io/en/latest/", None), "tad_dftd4": ("https://tad-dftd4.readthedocs.io/en/latest/", None), + "tad_libcint": ("https://tad-libcint.readthedocs.io/en/latest/", None), "tad_mctc": ("https://tad-mctc.readthedocs.io/en/latest/", None), "tad_multicharge": ("https://tad-multicharge.readthedocs.io/en/latest/", None), "torch": ("https://pytorch.org/docs/stable/", None), @@ -113,7 +114,9 @@ exclude_patterns = [ # Sometimes sphinx reads its own outputs as inputs! "build/html", + "_build/html", "build/jupyter_execute", + "_build/jupyter_execute", "notebooks/README.md", "README.md", "notebooks/*.md", diff --git a/examples/integrals.py b/examples/integrals.py index b85ea6e49..0295f8922 100644 --- a/examples/integrals.py +++ b/examples/integrals.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Simple integral interface. +Simple integral interface. Can be helpful for testing. """ from pathlib import Path diff --git a/examples/limitation_xitorch.py b/examples/limitation_xitorch.py index b2a80aed6..284261e49 100644 --- a/examples/limitation_xitorch.py +++ b/examples/limitation_xitorch.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Calculating forces for vancomycin via AD. +Example for xitorch's inability to be used together with functorch. """ from pathlib import Path diff --git a/pyproject.toml b/pyproject.toml index acfa4080c..386884583 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ source = ["./src"] omit = [ "./src/dxtb/_src/exlibs/xitorch/*", "./src/dxtb/_src/exlibs/scipy/*", - "./src/dxtb/_src/typing.py", + "./src/dxtb/_src/typing/*", "./src/dxtb/components/*", ] diff --git a/src/dxtb/__version__.py b/src/dxtb/__version__.py index 54fb8e144..46f2e018d 100644 --- a/src/dxtb/__version__.py +++ b/src/dxtb/__version__.py @@ -22,5 +22,5 @@ __all__ = ["__version__", "__tversion__"] -__version__ = "0.0.0" +__version__ = "0.0.1" """Version of ``dxtb`` in semantic versioning.""" diff --git a/src/dxtb/_src/cli/driver.py b/src/dxtb/_src/cli/driver.py index e983f7cff..7276dd50b 100644 --- a/src/dxtb/_src/cli/driver.py +++ b/src/dxtb/_src/cli/driver.py @@ -108,7 +108,7 @@ def _set_attr(self, attr: str) -> int | list[int]: for path in self.base: # use charge (or spin) from file or set to zero if Path(path, FILES[attr]).is_file(): - vals.append(io.read_chrg(Path(path, FILES[attr]))) + vals.append(read.read_chrg_from_path(Path(path, FILES[attr]))) else: vals.append(0) @@ -176,7 +176,7 @@ def singlepoint(self) -> Result | None: numbers = pack(_n) positions = pack(_p) else: - _n, _p = io.read_structure_from_file(args.file[0], args.filetype) + _n, _p = read.read_from_path(args.file[0], args.filetype) numbers = torch.tensor(_n, dtype=torch.long, device=dd["device"]) positions = torch.tensor(_p, **dd) diff --git a/src/dxtb/_src/io/__init__.py b/src/dxtb/_src/io/__init__.py index 13cd3a304..ec9b3f8c0 100644 --- a/src/dxtb/_src/io/__init__.py +++ b/src/dxtb/_src/io/__init__.py @@ -18,15 +18,5 @@ Functions for reading and writing files. """ -from . import read from .handler import * -from .logutils import DEFAULT_LOG_CONFIG from .output import * -from .read import ( - read_chrg, - read_coord, - read_orca_engrad, - read_qcschema, - read_structure_from_file, - read_xyz, -) diff --git a/src/dxtb/_src/io/logutils.py b/src/dxtb/_src/io/logutils.py index e867c07af..5187feb96 100644 --- a/src/dxtb/_src/io/logutils.py +++ b/src/dxtb/_src/io/logutils.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Logging +Logging. """ from __future__ import annotations diff --git a/src/dxtb/_src/io/output/header.py b/src/dxtb/_src/io/output/header.py index 96607aab0..ab4a187aa 100644 --- a/src/dxtb/_src/io/output/header.py +++ b/src/dxtb/_src/io/output/header.py @@ -26,7 +26,7 @@ WIDTH = 70 -def get_header() -> str: +def get_header() -> str: # pragma: no cover logo = [ r" _ _ _ ", r" | | | | | | ", diff --git a/src/dxtb/_src/io/output/info.py b/src/dxtb/_src/io/output/info.py index 48f670d7d..e83476aeb 100644 --- a/src/dxtb/_src/io/output/info.py +++ b/src/dxtb/_src/io/output/info.py @@ -26,6 +26,7 @@ import torch from dxtb.__version__ import __tversion__ +from dxtb._src.typing import Any __all__ = [ "get_mkl_num_threads", @@ -70,7 +71,7 @@ def get_system_info(): } -def get_pytorch_info(): +def get_pytorch_info() -> dict[str, Any]: # pragma: no cover is_cuda = torch.cuda.is_available() backends = [] @@ -129,7 +130,7 @@ def get_pytorch_info(): } -def print_system_info(punit=print): +def print_system_info(punit=print) -> None: # pragma: no cover system_info = get_system_info()["System Information"] pytorch_info = get_pytorch_info()["PyTorch Information"] sep = 17 diff --git a/src/dxtb/_src/io/read.py b/src/dxtb/_src/io/read.py deleted file mode 100644 index 587d68952..000000000 --- a/src/dxtb/_src/io/read.py +++ /dev/null @@ -1,362 +0,0 @@ -# This file is part of dxtb. -# -# SPDX-Identifier: Apache-2.0 -# Copyright (C) 2024 Grimme Group -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -IO: Reading Files -================= - -IO utility for reading files. -""" - -from __future__ import annotations - -from json import loads as json_load -from pathlib import Path - -import torch -from tad_mctc import units -from tad_mctc.data import pse - -from dxtb._src.typing import Any, PathLike - -__all__ = [ - "check_xyz", - "read_structure_from_file", - "read_xyz", - "read_xyz_qm9", - "read_qcschema", - "read_coord", - "read_chrg", - "read_tblite_gfn", - "read_orca_engrad", -] - - -def check_xyz(fp: PathLike, xyz: list[list[float]]) -> list[list[float]]: - """ - Check coordinates of file. Particularly, we check for the last coordinate - being at the origin as this might clash with padding. - - Parameters - ---------- - fp : PathLike - Path to coordinate file. - xyz : list[list[float]] - Coordinates of structure. - - Returns - ------- - list[list[float]] - Coordinates of structure. - - Raises - ------ - ValueError - File is actually empty or last coordinate is at origin. - """ - if len(xyz) == 0: - raise ValueError(f"File '{fp}' is empty.") - elif len(xyz) == 1: - if xyz[-1] == [0.0, 0.0, 0.0]: - xyz[-1][0] = 1.0 - else: - if xyz[-1] == [0.0, 0.0, 0.0]: - raise ValueError( - f"Last coordinate is zero in '{fp}'. This will clash with padding." - ) - return xyz - - -def read_structure_from_file( - file: PathLike, ftype: str | None = None -) -> tuple[list[int], list[list[float]]]: - """ - Helper to read the structure from the given file. - - Parameters - ---------- - file : PathLike - Path of file containing the structure. - ftype : str | None, optional - File type. Defaults to ``None``, i.e., infered from the extension. - - Returns - ------- - tuple[list[int], list[list[float]]] - Lists of atoms and coordinates. - - Raises - ------ - FileNotFoundError - File given does not exist. - NotImplementedError - Reader for specific file type not implemented. - ValueError - Unknown file type. - """ - f = Path(file) - if f.exists() is False: - raise FileNotFoundError(f"File '{f}' not found.") - - if ftype is None: - ftype = f.suffix.lower()[1:] - fname = f.name.lower() - - if ftype in ("xyz", "log"): - numbers, positions = read_xyz(f) - elif ftype == "qm9": - numbers, positions = read_xyz_qm9(f) - elif ftype in ("tmol", "tm", "turbomole") or fname == "coord": - numbers, positions = read_coord(f) - elif ftype in ("mol", "sdf", "gen", "pdb"): - raise NotImplementedError( - f"Filetype '{ftype}' recognized but no reader available." - ) - elif ftype in ("qchem"): - raise NotImplementedError( - f"Filetype '{ftype}' (Q-Chem) recognized but no reader available." - ) - elif ftype in ("poscar", "contcar", "vasp", "crystal") or fname in ( - "poscar", - "contcar", - "vasp", - ): - raise NotImplementedError( - "VASP/CRYSTAL file recognized but no reader available." - ) - elif ftype in ("ein", "gaussian"): - raise NotImplementedError( - f"Filetype '{ftype}' (Gaussian) recognized but no reader available." - ) - elif ftype in ("json", "qcschema"): - numbers, positions = read_qcschema(f) - else: - raise ValueError(f"Unknown filetype '{ftype}' in '{f}'.") - - return numbers, positions - - -def read_xyz(fp: PathLike) -> tuple[list[int], list[list[float]]]: - """ - Read xyz file. - - Parameters - ---------- - fp : PathLike - Path to coordinate file. - - Returns - ------- - tuple[list[int], list[list[float]]] - Lists containing the atomic numbers and coordinates. - """ - atoms = [] - xyz = [] - num_atoms = 0 - - with open(fp, encoding="utf-8") as file: - for line_number, line in enumerate(file): - if line_number == 0: - num_atoms = int(line) - elif line_number == 1: - continue - else: - l = line.strip().split() - atom, x, y, z = l - xyz.append([i * units.AA2AU for i in [float(x), float(y), float(z)]]) - atoms.append(pse.S2Z[atom.title()]) - - if len(xyz) != num_atoms: - raise ValueError(f"Number of atoms in {fp} does not match.") - - xyz = check_xyz(fp, xyz) - return atoms, xyz - - -def read_xyz_qm9(fp: PathLike) -> tuple[list[int], list[list[float]]]: - """ - Read the xyz files of the QM9 data set. The xyz files here do not conform - with the standard format. - - Parameters - ---------- - fp : PathLike - Path to coordinate file. - - Returns - ------- - tuple[list[int], list[list[float]]] - Lists containing the atomic numbers and coordinates. - """ - atoms = [] - xyz = [] - num_atoms = 0 - - with open(fp, encoding="utf-8") as file: - lines = file.readlines() - - num_atoms = int(lines[0].strip()) - - for i in range(2, 2 + num_atoms): - l = lines[i].strip().split() - - atoms.append(pse.S2Z[l[0].title()]) - xyz.append([float(x.replace("*^", "e")) * units.AA2AU for x in l[1:4]]) - - if len(xyz) != num_atoms: - raise ValueError(f"Number of atoms in {fp} does not match.") - - xyz = check_xyz(fp, xyz) - return atoms, xyz - - -def read_qcschema(fp: PathLike) -> tuple[list[int], list[list[float]]]: - """ - Read json/QCSchema file. - - Parameters - ---------- - fp : PathLike - Path to coord file. - - Returns - ------- - tuple[list[int], list[list[float]]] - Lists containing the atomic numbers and coordinates. - """ - with open(fp, encoding="utf-8") as file: - data = json_load(file.read()) - - if "molecule" not in data: - raise KeyError(f"Invalid schema: Key 'molecule' not found in '{fp}'.") - - mol = data["molecule"] - - if "symbols" not in mol: - raise KeyError(f"Invalid schema: Key 'symbols' not found in '{fp}'.") - if "geometry" not in mol: - raise KeyError(f"Invalid schema: Key 'geometry' not found in '{fp}'.") - - atoms = [] - for atom in mol["symbols"]: - atoms.append(pse.S2Z[atom.title()]) - - xyz = [] - geo = mol["geometry"] - for i in range(0, len(geo), 3): - xyz.append([float(geo[i]), float(geo[i + 1]), float(geo[i + 2])]) - - xyz = check_xyz(fp, xyz) - return atoms, xyz - - -def read_coord(fp: PathLike) -> tuple[list[int], list[list[float]]]: - """ - Read Turbomole/coord file. - - Parameters - ---------- - fp : PathLike - Path to coord file. - - Returns - ------- - tuple[list[int], list[list[float]]] - Lists containing the atomic numbers and coordinates. - """ - atoms = [] - xyz = [] - breakpoints = ["$user-defined bonds", "$redundant", "$end", "$periodic"] - - with open(fp, encoding="utf-8") as file: - for line in file: # pragma: no branch - # tests exist but somehow not covered? - l = line.split() - - # skip - if len(l) == 0: - continue - elif any(bp in line for bp in breakpoints): - break - elif l[0].startswith("$"): - continue - - if len(l) != 4: - raise ValueError(f"Format error in {fp}") - - x, y, z, atom = l - xyz.append([float(x), float(y), float(z)]) - atoms.append(pse.S2Z[atom.title()]) - - xyz = check_xyz(fp, xyz) - return atoms, xyz - - -def read_chrg(fp: PathLike) -> int: - """Read a chrg (or uhf) file.""" - - if not Path(fp).is_file(): - return 0 - - with open(fp, encoding="utf-8") as file: - return int(file.read()) - - -def read_tblite_gfn(fp: Path | str) -> dict[str, Any]: - """Read energy file from tblite json output.""" - with open(fp, encoding="utf-8") as file: - return json_load(file.read()) - - -def read_orca_engrad(fp: Path | str) -> tuple[float, list[float]]: - """Read ORCA's engrad file.""" - start_grad = -1 - grad = [] - - start_energy = -1 - energy = 0.0 - with open(fp, encoding="utf-8") as file: - for i, line in enumerate(file): # pragma: no branch - # tests exist but somehow not covered? - - # energy - if line.startswith("# The current total energy in Eh"): - start_energy = i + 2 - - if i == start_energy: - l = line.strip() - if len(l) == 0: - raise ValueError(f"No energy found in {fp}.") - energy = float(l) - start_energy = -1 - - # gradient - if line.startswith("# The current gradient in Eh/bohr"): - start_grad = i + 2 - - if i == start_grad: - # abort if we hit the next "#" - if line.startswith("#"): - break - - l = line.strip() - if len(l) == 0: - raise ValueError(f"No gradient found in {fp}.") - - grad.append(float(l)) - start_grad += 1 - - return energy, torch.tensor(grad).reshape(-1, 3).tolist() diff --git a/src/dxtb/_src/loader/lazy/lazy_var.py b/src/dxtb/_src/loader/lazy/lazy_var.py index 9f869e4a1..a3874b674 100644 --- a/src/dxtb/_src/loader/lazy/lazy_var.py +++ b/src/dxtb/_src/loader/lazy/lazy_var.py @@ -49,7 +49,7 @@ import importlib -from dxtb._src.typing import Any, Callable, Sequence +from dxtb._src.typing import Any, Callable, Mapping, Sequence __all__ = ["attach_var", "attach_vars"] @@ -117,7 +117,7 @@ def __dir__() -> list[str]: return __getattr__, __dir__, __all__ -def attach_vars(module_vars: dict[str, Sequence[str]]) -> tuple[ +def attach_vars(module_vars: Mapping[str, Sequence[str]]) -> tuple[ Callable[[str], Any], Callable[[], list[str]], list[str], diff --git a/src/dxtb/_src/param/__init__.py b/src/dxtb/_src/param/__init__.py index 9c68a27e0..944f0c4c7 100644 --- a/src/dxtb/_src/param/__init__.py +++ b/src/dxtb/_src/param/__init__.py @@ -54,7 +54,7 @@ class also supports reading JSON and YAML formats. from pydantic import __version__ as pydantic_version -if tuple(map(int, pydantic_version.split("."))) < (2, 0, 0): +if tuple(map(int, pydantic_version.split("."))) < (2, 0, 0): # pragma: no cover raise RuntimeError( "pydantic version outdated: dxtb requires pydantic >=2.0.0 " f"(version {pydantic_version} installed)." diff --git a/src/dxtb/_src/scf/iterator.py b/src/dxtb/_src/scf/iterator.py index 4b8388766..060459745 100644 --- a/src/dxtb/_src/scf/iterator.py +++ b/src/dxtb/_src/scf/iterator.py @@ -91,6 +91,12 @@ def solve( n0, occupation = get_refocc(refocc, chrg, spin, ihelp) charges = get_guess(numbers, positions, chrg, ihelp, config.guess) + if not isinstance(config.scf_mode, int): + raise ValueError( + "SCF mode must be an integer within `solve`. This can only " + "happen if you explicitly change the configuration object." + ) + if config.scf_mode == labels.SCF_MODE_IMPLICIT: # pylint: disable=import-outside-toplevel from .pure import scf_wrapper @@ -119,8 +125,7 @@ def solve( # pylint: disable=import-outside-toplevel from .unrolling import SelfConsistentFieldSingleShot as SCF else: - name = labels.SCF_MODE_MAP[config.scf_mode] - raise ValueError(f"Unknown SCF mode '{name}' (input name can vary).") + raise ValueError(f"Unknown SCF mode '{config.scf_mode}'.") return SCF( interactions, diff --git a/src/dxtb/_src/scf/pure/config.py b/src/dxtb/_src/scf/pure/config.py deleted file mode 100644 index ffc582c7d..000000000 --- a/src/dxtb/_src/scf/pure/config.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -SCF Configuration -================= - -This module defines a storage class for the SCF options. -""" - -from __future__ import annotations - -import torch -from tad_mctc.units.energy import KELVIN2AU - -from dxtb._src.constants import defaults -from dxtb._src.typing import Any - -from .data import _Data - -__all__ = ["SCFConfig"] - - -class SCFConfig: - """ - Self-consistent field configuration, as pure base class containing only - configuration information. - - This class should _not_ contain any tensors, which store AG gradients - during SCF iterations. - """ - - fwd_options: dict[str, Any] - """Options for forwards pass""" - - bck_options: dict[str, Any] - """Options for backwards pass""" - - eigen_options: dict[str, Any] - """Options for eigensolver""" - - scf_options: dict[str, Any] - """ - Options for SCF: - - - "etemp": Electronic temperature (in a.u.) for Fermi smearing. - - "fermi_maxiter": Maximum number of iterations for Fermi smearing. - - "fermi_thresh": Float data type dependent threshold for Fermi iterations. - - "fermi_fenergy_partition": Partitioning scheme for electronic free energy. - """ - - use_potential: bool - """Whether to use the potential or the charges""" - - batch_mode: int - """Whether multiple systems or a single one are handled""" - - def __init__(self, data: _Data, batch_mode: int, **kwargs: Any) -> None: - self.bck_options = {"posdef": True, **kwargs.pop("bck_options", {})} - self.fwd_options = { - "force_convergence": False, - "method": "broyden1", - "alpha": -0.5, - "f_tol": defaults.F_ATOL, - "x_tol": defaults.X_ATOL, - "f_rtol": float("inf"), - "x_rtol": float("inf"), - "maxiter": defaults.MAXITER, - "verbose": False, - "line_search": False, - **kwargs.pop("fwd_options", {}), - } - - self.eigen_options = {"method": "exacteig", **kwargs.pop("eigen_options", {})} - - self.scf_options = {**kwargs.pop("scf_options", {})} - self.scp_mode = self.scf_options.get("scp_mode", defaults.SCP_MODE) - - # Only infer shapes and types from _Data (no logic involved), - # i.e. keep _Data and SCFConfig instances disjunct objects. - self._shape = data.ints.hcore.shape - self._dtype = data.ints.hcore.dtype - self._device = data.ints.hcore.device - - self.kt = data.ints.hcore.new_tensor( - self.scf_options.get("etemp", defaults.FERMI_ETEMP) * KELVIN2AU - ) - self.batch_mode = batch_mode - - @property - def shape(self) -> torch.Size: - """ - Returns the shape of the density matrix in this engine. - """ - return self._shape - - @property - def dtype(self) -> torch.dtype: - """ - Returns the dtype of the tensors in this engine. - """ - return self._dtype - - @property - def device(self) -> torch.device: - """ - Returns the device of the tensors in this engine. - """ - return self._device diff --git a/src/dxtb/_src/typing/builtin.py b/src/dxtb/_src/typing/builtin.py index e60b76244..cfe61fb6a 100644 --- a/src/dxtb/_src/typing/builtin.py +++ b/src/dxtb/_src/typing/builtin.py @@ -28,6 +28,7 @@ Iterable, Iterator, Literal, + Mapping, NoReturn, Protocol, Type, @@ -46,6 +47,7 @@ "Iterable", "Iterator", "Literal", + "Mapping", "NoReturn", "Protocol", "Type", diff --git a/test/test_hamiltonian/test_h0.py b/test/test_hamiltonian/test_h0.py index f88f953ec..abd382ff7 100644 --- a/test/test_hamiltonian/test_h0.py +++ b/test/test_hamiltonian/test_h0.py @@ -153,3 +153,56 @@ def test_large_batch(dtype: torch.dtype, name1: str, name2: str) -> None: ) run(numbers, positions, GFN1_XTB, ref, dd) + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_no_cn(dtype: torch.dtype) -> None: + """Test without CN.""" + tol = sqrt(torch.finfo(dtype).eps) + dd: DD = {"dtype": dtype, "device": DEVICE} + + sample = samples["H2"] + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + + ref = torch.tensor( + [ + [ + -0.40142945681830, + -0.00000000000000, + -0.47765679842079, + -0.03687145777483, + ], + [ + -0.00000000000000, + -0.07981592633195, + -0.03687145777483, + -0.02334876845340, + ], + [ + -0.47765679842079, + -0.03687145777483, + -0.40142945681830, + -0.00000000000000, + ], + [ + -0.03687145777483, + -0.02334876845340, + -0.00000000000000, + -0.07981592633195, + ], + ], + **dd, + ) + + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + driver = IntDriver(numbers, GFN1_XTB, ihelp, **dd) + overlap = Overlap(**dd) + h0 = GFN1Hamiltonian(numbers, GFN1_XTB, ihelp, **dd) + + driver.setup(positions) + s = overlap.build(driver) + + h = h0.build(positions, s) + assert pytest.approx(h.cpu(), abs=tol) == h.mT.cpu() + assert pytest.approx(h.cpu(), abs=tol) == ref.cpu() diff --git a/test/test_integrals/test_types.py b/test/test_integrals/test_types.py new file mode 100644 index 000000000..ff8f13796 --- /dev/null +++ b/test/test_integrals/test_types.py @@ -0,0 +1,40 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test overlap build from integral container. +""" + +from __future__ import annotations + +import pytest +import torch + +from dxtb import GFN1_XTB, GFN2_XTB, IndexHelper +from dxtb.integrals import types as inttypes + +numbers = torch.tensor([14, 1, 1, 1, 1]) + + +def test_fail() -> None: + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + + with pytest.raises(ValueError): + par1 = GFN1_XTB.model_copy(deep=True) + assert par1.meta is not None + + par1.meta.name = "fail" + inttypes.HCore(numbers, par1, ihelp) diff --git a/test/test_integrals/test_wrappers.py b/test/test_integrals/test_wrappers.py new file mode 100644 index 000000000..547e77410 --- /dev/null +++ b/test/test_integrals/test_wrappers.py @@ -0,0 +1,95 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test overlap build from integral container. +""" + +from __future__ import annotations + +import pytest +import torch + +from dxtb import GFN1_XTB, GFN2_XTB, Param +from dxtb.integrals import wrappers + +numbers = torch.tensor([14, 1, 1, 1, 1]) +positions = torch.tensor( + [ + [+0.00000000000000, +0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], + [-1.61768389755830, -1.61768389755830, -1.61768389755830], + [+1.61768389755830, -1.61768389755830, +1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], + ] +) + + +def test_fail() -> None: + with pytest.raises(TypeError): + par1 = GFN1_XTB.model_copy(deep=True) + par1.meta = None + wrappers.hcore(numbers, positions, par1) + + with pytest.raises(TypeError): + par1 = GFN1_XTB.model_copy(deep=True) + assert par1.meta is not None + + par1.meta.name = None + wrappers.hcore(numbers, positions, par1) + + with pytest.raises(ValueError): + par1 = GFN1_XTB.model_copy(deep=True) + assert par1.meta is not None + + par1.meta.name = "fail" + wrappers.hcore(numbers, positions, par1) + + with pytest.raises(ValueError): + # pylint: disable=import-outside-toplevel + from dxtb._src.integral.wrappers import _integral + + _integral("fail", numbers, positions, par1) # type: ignore + + +@pytest.mark.parametrize("par", [GFN1_XTB]) +def test_h0_gfn1(par: Param) -> None: + h0 = wrappers.hcore(numbers, positions, par) + assert h0.shape == (17, 17) + + h0 = wrappers.hcore(numbers, positions, par, cn=torch.zeros(numbers.shape)) + assert h0.shape == (17, 17) + + +@pytest.mark.parametrize("par", [GFN2_XTB]) +def test_h0_gfn2(par: Param) -> None: + with pytest.raises(NotImplementedError): + wrappers.hcore(numbers, positions, par) + + +def test_overlap() -> None: + s = wrappers.overlap(numbers, positions, GFN1_XTB) + assert s.shape == (17, 17) + + +def test_dipole() -> None: + s = wrappers.dipint(numbers, positions, GFN1_XTB) + assert s.shape == (3, 17, 17) + + +def test_quad() -> None: + s = wrappers.quadint(numbers, positions, GFN1_XTB) + assert s.shape == (9, 17, 17) diff --git a/test/test_io/__init__.py b/test/test_io/__init__.py new file mode 100644 index 000000000..15d042be4 --- /dev/null +++ b/test/test_io/__init__.py @@ -0,0 +1,16 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test/test_io/test_logging.py b/test/test_io/test_logging.py new file mode 100644 index 000000000..a5f823c92 --- /dev/null +++ b/test/test_io/test_logging.py @@ -0,0 +1,32 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test logging utils. +""" +from __future__ import annotations + +from dxtb._src.io.logutils import DEFAULT_LOG_CONFIG, get_logging_config + + +def test_config(): + config = get_logging_config() + assert config["level"] == DEFAULT_LOG_CONFIG["level"] + assert config["format"] == DEFAULT_LOG_CONFIG["format"] + assert config["datefmt"] == DEFAULT_LOG_CONFIG["datefmt"] + + config = get_logging_config(level="debug") + assert config["level"] == "debug" diff --git a/test/test_io/test_outputs.py b/test/test_io/test_outputs.py new file mode 100644 index 000000000..d515bd784 --- /dev/null +++ b/test/test_io/test_outputs.py @@ -0,0 +1,106 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test output. +""" +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from dxtb import __version__ +from dxtb._src.io.output import ( + get_mkl_num_threads, + get_omp_num_threads, + get_python_version, + get_pytorch_version_short, + get_short_version, + get_system_info, +) + + +@patch("platform.python_version") +def test_python_version(mocker) -> None: + mocker.return_value = "3.8.5" + assert get_python_version() == "3.8.5" + + +@patch("torch.__config__.show") +def test_get_pytorch_version_short(mocker) -> None: + mocker.return_value = "config,TORCH_VERSION=1.7.1,other" + assert get_pytorch_version_short() == "1.7.1" + + +@patch("torch.__config__.show") +def test_get_pytorch_version_short_raises_error(mocker) -> None: + mocker.return_value = "config,other" + + with pytest.raises(RuntimeError, match="Version string not found in config."): + get_pytorch_version_short() + + +@patch("platform.python_version") +@patch("torch.__config__.show") +def test_get_short_version(mocker_torch, mocker_python) -> None: + mocker_torch.return_value = "config,TORCH_VERSION=1.7.1,other" + mocker_python.return_value = "3.8.5" + + msg = f"* dxtb version {__version__} running with Python 3.8.5 and PyTorch 1.7.1\n" + assert get_short_version() == msg + + +############################################################################### + + +def test_get_omp_num_threads() -> None: + # Mock torch.__config__.parallel_info to return a controlled string + mock_parallel_info = MagicMock() + mock_parallel_info.return_value = "some_info\nOMP_NUM_THREADS=4\nother_info" + + with patch("torch.__config__.parallel_info", mock_parallel_info): + omp_num_threads = get_omp_num_threads() + assert omp_num_threads == "OMP_NUM_THREADS=4" + + +def test_get_mkl_num_threads() -> None: + # Mock torch.__config__.parallel_info to return a controlled string + mock_parallel_info = MagicMock() + mock_parallel_info.return_value = "some_info\nMKL_NUM_THREADS=8\nother_info" + + with patch("torch.__config__.parallel_info", mock_parallel_info): + mkl_num_threads = get_mkl_num_threads() + assert mkl_num_threads == "MKL_NUM_THREADS=8" + + +def test_get_system_info() -> None: + with patch("platform.system", return_value="Linux"): + with patch("platform.machine", return_value="x86_64"): + with patch("platform.release", return_value="5.4.0-74-generic"): + with patch("platform.node", return_value="test-host"): + with patch("os.cpu_count", return_value=8): + system_info = get_system_info() + expected_info = { + "System Information": { + "Operating System": "Linux", + "Architecture": "x86_64", + "OS Version": "5.4.0-74-generic", + "Hostname": "test-host", + "CPU Count": 8, + } + } + assert system_info == expected_info diff --git a/test/test_loader/__init__.py b/test/test_loader/__init__.py new file mode 100644 index 000000000..15d042be4 --- /dev/null +++ b/test/test_loader/__init__.py @@ -0,0 +1,16 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test/test_loader/test_lazy/__init__.py b/test/test_loader/test_lazy/__init__.py new file mode 100644 index 000000000..15d042be4 --- /dev/null +++ b/test/test_loader/test_lazy/__init__.py @@ -0,0 +1,16 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test/test_loader/test_lazy/test_attach_module.py b/test/test_loader/test_lazy/test_attach_module.py new file mode 100644 index 000000000..048b8327d --- /dev/null +++ b/test/test_loader/test_lazy/test_attach_module.py @@ -0,0 +1,64 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test the lazy loaders. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from dxtb._src.loader.lazy import attach_module + + +def test_attach_module_imports_submodules(): + package_name = "test_package" + submodules = ["sub1", "sub2"] + + # Mock importlib.import_module to simulate module imports + with patch("importlib.import_module") as mock_import_module: + mock_import_module.side_effect = lambda name: f"module_{name}" + + __getattr__, __dir__, __all__ = attach_module(package_name, submodules) + + # Test __getattr__ for existing submodules + assert __getattr__("sub1") == "module_test_package.sub1" + assert __getattr__("sub2") == "module_test_package.sub2" + + # Test __dir__ returns the list of submodules + assert __dir__() == submodules + + # Test __all__ contains the submodules + assert __all__ == submodules + + +def test_attach_module_raises_attribute_error_for_nonexistent_submodules(): + package_name = "test_package" + submodules = ["sub1", "sub2"] + + # Mock importlib.import_module to simulate module imports + with patch("importlib.import_module") as mock_import_module: + mock_import_module.side_effect = lambda name: f"module_{name}" + + __getattr__, __dir__, __all__ = attach_module(package_name, submodules) + + # Test __getattr__ raises AttributeError for non-existent submodules + msg = f"The module '{package_name}' has no attribute 'sub3." + with pytest.raises(AttributeError, match=msg): + __getattr__("sub3") diff --git a/test/test_loader/test_lazy/test_attach_var.py b/test/test_loader/test_lazy/test_attach_var.py new file mode 100644 index 000000000..7229d89d4 --- /dev/null +++ b/test/test_loader/test_lazy/test_attach_var.py @@ -0,0 +1,122 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test the lazy loaders. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from dxtb._src.loader.lazy import attach_var, attach_vars + + +def test_attach_var_imports_variables(): + package_name = "test_package" + varnames = ["var1", "var2"] + + # Mock importlib.import_module to simulate module imports + with patch("importlib.import_module") as mock_import_module: + mock_module = MagicMock() + mock_module.var1 = "value1" + mock_module.var2 = "value2" + mock_import_module.return_value = mock_module + + __getattr__, __dir__, __all__ = attach_var(package_name, varnames) + + # Test __getattr__ for existing variables + assert __getattr__("var1") == "value1" + assert __getattr__("var2") == "value2" + + # Test __dir__ returns the list of variables + assert __dir__() == varnames + + # Test __all__ contains the variables + assert __all__ == varnames + + +def test_attach_var_raises_attribute_error_for_nonexistent_variables(): + package_name = "test_package" + varnames = ["var1", "var2"] + + # Mock importlib.import_module to simulate module imports + with patch("importlib.import_module") as mock_import_module: + mock_module = MagicMock() + mock_module.var1 = "value1" + mock_import_module.return_value = mock_module + + __getattr__, __dir__, __all__ = attach_var(package_name, varnames) + + # Test __getattr__ raises AttributeError for non-existent variables + msg = f"The module '{package_name}' has no attribute 'var3." + with pytest.raises(AttributeError, match=msg): + __getattr__("var3") + + +def test_attach_vars_imports_variables(): + module_vars = {"package1": ["var1", "var2"], "package2": ["var3", "var4"]} + + # Mock importlib.import_module to simulate module imports + with patch("importlib.import_module") as mock_import_module: + mock_package1 = MagicMock() + mock_package1.var1 = "value1" + mock_package1.var2 = "value2" + mock_package2 = MagicMock() + mock_package2.var3 = "value3" + mock_package2.var4 = "value4" + mock_import_module.side_effect = lambda name: ( + mock_package1 if name == "package1" else mock_package2 + ) + + __getattr__, __dir__, __all__ = attach_vars(module_vars) + + # Test __getattr__ for existing variables + assert __getattr__("var1") == "value1" + assert __getattr__("var2") == "value2" + assert __getattr__("var3") == "value3" + assert __getattr__("var4") == "value4" + + # Test __dir__ returns the list of variables + assert __dir__() == ["var1", "var2", "var3", "var4"] + + # Test __all__ contains the variables + assert __all__ == ["var1", "var2", "var3", "var4"] + + +def test_attach_vars_raises_attribute_error_for_nonexistent_variables(): + module_vars = {"package1": ["var1", "var2"], "package2": ["var3", "var4"]} + + # Mock importlib.import_module to simulate module imports + with patch("importlib.import_module") as mock_import_module: + mock_package1 = MagicMock() + mock_package1.var1 = "value1" + mock_package1.var2 = "value2" + mock_package2 = MagicMock() + mock_package2.var3 = "value3" + mock_package2.var4 = "value4" + mock_import_module.side_effect = lambda name: ( + mock_package1 if name == "package1" else mock_package2 + ) + + __getattr__, __dir__, __all__ = attach_vars(module_vars) + + # Test __getattr__ raises AttributeError for non-existent variables + msg = f"No module contains the variable 'var5'." + with pytest.raises(AttributeError, match=msg): + __getattr__("var5") diff --git a/test/test_loader/test_lazy/test_param.py b/test/test_loader/test_lazy/test_param.py new file mode 100644 index 000000000..d18a2037a --- /dev/null +++ b/test/test_loader/test_lazy/test_param.py @@ -0,0 +1,69 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test the lazy loaders. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from dxtb._src.loader.lazy import LazyLoaderParam + + +def test_lazy_loader_param_initialization() -> None: + filepath = "test.toml" + loader = LazyLoaderParam(filepath) + assert loader.filepath == filepath + assert loader._loaded is None + + +def test_lazy_loader_param_str() -> None: + filepath = "test.toml" + loader = LazyLoaderParam(filepath) + assert str(loader) == f"LazyLoaderParam({filepath})" + + +def test_lazy_loader_param_repr() -> None: + filepath = "test.toml" + loader = LazyLoaderParam(filepath) + assert repr(loader) == f"LazyLoaderParam({filepath})" + + +@pytest.mark.parametrize("parname", ["gfn1-xtb", "gfn2-xtb"]) +def test_lazy_loader_param_equality(parname: str) -> None: + p = ( + Path(__file__).parents[3] + / "src" + / "dxtb" + / "_src" + / "param" + / parname.split("-")[0] + / f"{parname}.toml" + ) + + loader1 = LazyLoaderParam(p) + loader2 = LazyLoaderParam(p) + + # Trigger the lazy loading + _ = loader1.meta + _ = loader2.meta + + assert loader1 == loader2 + assert loader1._loaded == loader2._loaded diff --git a/test/test_scf/test_general.py b/test/test_scf/test_general.py index 18c1af109..83cc45bed 100644 --- a/test/test_scf/test_general.py +++ b/test/test_scf/test_general.py @@ -43,3 +43,21 @@ def test_properties() -> None: assert scf.shape == d.shape assert scf.device == d.device assert scf.dtype == d.dtype + + +def test_fail() -> None: + from dxtb import GFN1_XTB + from dxtb.calculators import EnergyCalculator + + numbers = torch.tensor([1]) + positions = torch.tensor([[0.0, 0.0, 0.0]]) + + calc = EnergyCalculator(numbers, GFN1_XTB) + + with pytest.raises(ValueError): + calc.opts.scf.scf_mode = -1 + calc.singlepoint(positions) + + with pytest.raises(ValueError): + calc.opts.scf.scf_mode = "fail" # type: ignore + calc.singlepoint(positions) diff --git a/test/test_scf/test_guess.py b/test/test_scf/test_guess.py index 8544dea76..3c7685ae9 100644 --- a/test/test_scf/test_guess.py +++ b/test/test_scf/test_guess.py @@ -23,7 +23,7 @@ import pytest import torch -from dxtb import IndexHelper +from dxtb import IndexHelper, labels from dxtb._src.scf import guess from ..conftest import DEVICE @@ -44,6 +44,9 @@ def test_fail() -> None: with pytest.raises(ValueError): guess.get_guess(numbers, positions, charge, ihelp, name="eht") + with pytest.raises(ValueError): + guess.get_guess(numbers, positions, charge, ihelp, name=1000) + # charges change because IndexHelper is broken with pytest.raises(RuntimeError): ih = IndexHelper.from_numbers_angular(numbers, {1: [0, 0], 6: [0, 1]}) @@ -51,8 +54,9 @@ def test_fail() -> None: guess.get_guess(numbers, positions, charge, ih) -def test_eeq() -> None: - c = guess.get_guess(numbers, positions, charge, ihelp) +@pytest.mark.parametrize("name", ["eeq", labels.GUESS_EEQ]) +def test_eeq(name: str | int) -> None: + c = guess.get_guess(numbers, positions, charge, ihelp, name=name) ref = torch.tensor( [ -0.11593066900969, @@ -67,8 +71,9 @@ def test_eeq() -> None: assert pytest.approx(ref.cpu(), abs=1e-5) == c.cpu() -def test_sad() -> None: - c = guess.get_guess(numbers, positions, charge, ihelp, name="sad") +@pytest.mark.parametrize("name", ["sad", labels.GUESS_SAD]) +def test_sad(name: str | int) -> None: + c = guess.get_guess(numbers, positions, charge, ihelp, name=name) size = int(ihelp.orbitals_per_shell.sum().item()) assert pytest.approx(torch.zeros(size).cpu()) == c.cpu() diff --git a/test/test_singlepoint/test_energy.py b/test/test_singlepoint/test_energy.py index 2c01e7e18..1e69983cb 100644 --- a/test/test_singlepoint/test_energy.py +++ b/test/test_singlepoint/test_energy.py @@ -26,11 +26,11 @@ import pytest import torch from tad_mctc.batch import pack +from tad_mctc.io import read from dxtb import GFN1_XTB as par from dxtb import Calculator from dxtb._src.constants import labels -from dxtb._src.io import read_chrg, read_coord from dxtb._src.typing import DD from ..conftest import DEVICE @@ -56,8 +56,8 @@ def test_single(dtype: torch.dtype, name: str, scf_mode: str) -> None: base = Path(Path(__file__).parent, "mols", name) - numbers, positions = read_coord(Path(base, "coord")) - charge = read_chrg(Path(base, ".CHRG")) + numbers, positions = read.read_from_path(Path(base, "coord")) + charge = read.read_chrg_from_path(Path(base, ".CHRG")) numbers = torch.tensor(numbers, dtype=torch.long, device=DEVICE) positions = torch.tensor(positions, **dd) @@ -90,8 +90,8 @@ def test_single_large(dtype: torch.dtype, name: str, scf_mode: str) -> None: base = Path(Path(__file__).parent, "mols", name) - numbers, positions = read_coord(Path(base, "coord")) - charge = read_chrg(Path(base, ".CHRG")) + numbers, positions = read.read_from_path(Path(base, "coord")) + charge = read.read_chrg_from_path(Path(base, ".CHRG")) numbers = torch.tensor(numbers, dtype=torch.long, device=DEVICE) positions = torch.tensor(positions, **dd) @@ -128,9 +128,8 @@ def test_batch( numbers, positions, charge = [], [], [] for name in [name1, name2, name3]: base = Path(Path(__file__).parent, "mols", name) - - nums, pos = read_coord(Path(base, "coord")) - chrg = read_chrg(Path(base, ".CHRG")) + nums, pos = read.read_from_path(Path(base, "coord")) + chrg = read.read_chrg_from_path(Path(base, ".CHRG")) numbers.append(torch.tensor(nums, dtype=torch.long, device=DEVICE)) positions.append(torch.tensor(pos, **dd)) @@ -178,8 +177,8 @@ def test_batch_large( for name in [name1, name2, name3]: base = Path(Path(__file__).parent, "mols", name) - nums, pos = read_coord(Path(base, "coord")) - chrg = read_chrg(Path(base, ".CHRG")) + nums, pos = read.read_from_path(Path(base, "coord")) + chrg = read.read_chrg_from_path(Path(base, ".CHRG")) numbers.append(torch.tensor(nums, dtype=torch.long, device=DEVICE)) positions.append(torch.tensor(pos, **dd)) @@ -218,9 +217,8 @@ def test_uhf_single(dtype: torch.dtype, name: str) -> None: dd: DD = {"device": DEVICE, "dtype": dtype} base = Path(Path(__file__).parent, "mols", name) - - numbers, positions = read_coord(Path(base, "coord")) - charge = read_chrg(Path(base, ".CHRG")) + numbers, positions = read.read_from_path(Path(base, "coord")) + charge = read.read_chrg_from_path(Path(base, ".CHRG")) numbers = torch.tensor(numbers, dtype=torch.long, device=DEVICE) positions = torch.tensor(positions, **dd) diff --git a/test/test_singlepoint/test_general.py b/test/test_singlepoint/test_general.py index 0dc8f2a30..b1094457f 100644 --- a/test/test_singlepoint/test_general.py +++ b/test/test_singlepoint/test_general.py @@ -24,10 +24,10 @@ import pytest import torch +from tad_mctc.io import read from dxtb import GFN1_XTB as par from dxtb import Calculator -from dxtb._src.io import read_chrg, read_coord from dxtb._src.timing import timer from ..conftest import DEVICE @@ -35,11 +35,6 @@ opts = {"verbosity": 0, "int_level": 4} -def test_fail() -> None: - with pytest.raises(FileNotFoundError): - read_coord(Path("non-existing-coord-file")) - - def test_uhf_fail() -> None: # singlepoint starts SCF timer, but exception is thrown before the SCF # timer is stopped, so we must disable it here @@ -49,8 +44,8 @@ def test_uhf_fail() -> None: base = Path(Path(__file__).parent, "mols", "H") - numbers, positions = read_coord(Path(base, "coord")) - charge = read_chrg(Path(base, ".CHRG")) + numbers, positions = read.read_from_path(Path(base, "coord")) + charge = read.read_chrg_from_path(Path(base, ".CHRG")) numbers = torch.tensor(numbers, dtype=torch.long) positions = torch.tensor(positions) diff --git a/test/test_singlepoint/test_grad.py b/test/test_singlepoint/test_grad.py index 7246a001e..b7f8fc252 100644 --- a/test/test_singlepoint/test_grad.py +++ b/test/test_singlepoint/test_grad.py @@ -26,11 +26,11 @@ import numpy as np import pytest import torch +from tad_mctc.io import read from dxtb import GFN1_XTB as par from dxtb import Calculator from dxtb._src.constants import labels -from dxtb._src.io import read_chrg, read_coord from dxtb._src.typing import DD, Tensor from ..conftest import DEVICE @@ -91,8 +91,8 @@ def analytical( # read from file base = Path(Path(__file__).parent, "mols", name) - numbers, positions = read_coord(Path(base, "coord")) - charge = read_chrg(Path(base, ".CHRG")) + numbers, positions = read.read_from_path(Path(base, "coord")) + charge = read.read_chrg_from_path(Path(base, ".CHRG")) # convert to tensors numbers = torch.tensor(numbers, dtype=torch.long, device=DEVICE) @@ -127,8 +127,8 @@ def test_backward(dtype: torch.dtype, name: str, scf_mode: str) -> None: # read from file base = Path(Path(__file__).parent, "mols", name) - numbers, positions = read_coord(Path(base, "coord")) - charge = read_chrg(Path(base, ".CHRG")) + numbers, positions = read.read_from_path(Path(base, "coord")) + charge = read.read_chrg_from_path(Path(base, ".CHRG")) # convert to tensors numbers = torch.tensor(numbers, dtype=torch.long, device=DEVICE) @@ -173,8 +173,8 @@ def test_num(name: str, scf_mode: str) -> None: # read from file base = Path(Path(__file__).parent, "mols", name) - numbers, positions = read_coord(Path(base, "coord")) - charge = read_chrg(Path(base, ".CHRG")) + numbers, positions = read.read_from_path(Path(base, "coord")) + charge = read.read_chrg_from_path(Path(base, ".CHRG")) # convert to tensors numbers = torch.tensor(numbers, dtype=torch.long, device=DEVICE) diff --git a/test/test_singlepoint/test_hess.py b/test/test_singlepoint/test_hess.py index 767b6c4c1..08e23a243 100644 --- a/test/test_singlepoint/test_hess.py +++ b/test/test_singlepoint/test_hess.py @@ -27,11 +27,11 @@ import torch from tad_mctc.autograd import jacrev from tad_mctc.convert import reshape_fortran +from tad_mctc.io import read from dxtb import GFN1_XTB as par from dxtb import Calculator from dxtb._src.constants import labels -from dxtb._src.io import read_chrg, read_coord from dxtb._src.typing import DD, Tensor from ..conftest import DEVICE @@ -59,8 +59,8 @@ def test_single(dtype: torch.dtype, name: str) -> None: # read from file base = Path(Path(__file__).parent, "mols", name) - numbers, positions = read_coord(Path(base, "coord")) - charge = read_chrg(Path(base, ".CHRG")) + numbers, positions = read.read_from_path(Path(base, "coord")) + charge = read.read_chrg_from_path(Path(base, ".CHRG")) # convert to tensors numbers = torch.tensor(numbers, dtype=torch.long, device=DEVICE) diff --git a/test/test_utils/test_timer.py b/test/test_utils/test_timer.py index 1c26a16c3..f1e8b88e5 100644 --- a/test/test_utils/test_timer.py +++ b/test/test_utils/test_timer.py @@ -20,9 +20,11 @@ from __future__ import annotations +from unittest.mock import patch + import pytest -from dxtb._src.timing.timer import TimerError, _Timers +from dxtb._src.timing.timer import TimerError, _sync, _Timers def test_fail() -> None: @@ -63,3 +65,21 @@ def test_stopall() -> None: assert not timer.timers["test"].is_running() assert not timer.timers["test2"].is_running() + + +@patch("torch.cuda.synchronize") +@patch("torch.cuda.is_available", return_value=False) +def test_sync_false(mocker_avail, mocker_sync) -> None: + _sync() + + mocker_avail.assert_called_once() + mocker_sync.assert_not_called() + + +@patch("torch.cuda.synchronize") +@patch("torch.cuda.is_available", return_value=True) +def test_sync_true(mocker_avail, mocker_sync) -> None: + _sync() + + mocker_avail.assert_called_once() + mocker_sync.assert_called_once()