diff --git a/janus_core/calculations/base.py b/janus_core/calculations/base.py index eb8b3756..554345be 100644 --- a/janus_core/calculations/base.py +++ b/janus_core/calculations/base.py @@ -18,6 +18,20 @@ from janus_core.helpers.struct_io import input_structs from janus_core.helpers.utils import FileNameMixin, none_to_dict +UNITS = { + "energy": "eV", + "forces": "ev/Ang", + "stress": "ev/Ang^3", + "hessian": "ev/Ang^2", + "time": "fs", + "real_time": "s", + "temperature": "K", + "pressure": "GPa", + "momenta": "(eV*u)^0.5", + "density": "g/cm^3", + "volume": "Ang^3", +} + class BaseCalculation(FileNameMixin): """ @@ -202,3 +216,21 @@ def __init__( self.tracker = config_tracker( self.logger, self.track_carbon, **self.tracker_kwargs ) + + def _set_info_units( + self, keys: Sequence[str] = ("energy", "forces", "stress") + ) -> None: + """ + Save units to structure info. + + Parameters + ---------- + keys : Sequence + Keys for which to add units to structure info. Default is + ("energy", "forces", "stress"). + """ + if isinstance(self.struct, Sequence): + for image in self.struct: + image.info["units"] = {key: UNITS[key] for key in keys} + else: + self.struct.info["units"] = {key: UNITS[key] for key in keys} diff --git a/janus_core/calculations/eos.py b/janus_core/calculations/eos.py index bd7bc863..e6bddf6a 100644 --- a/janus_core/calculations/eos.py +++ b/janus_core/calculations/eos.py @@ -290,6 +290,8 @@ def run(self) -> EoSResults: Dictionary containing equation of state ASE object, and the fitted minimum bulk modulus, volume, and energy. """ + self._set_info_units() + if self.minimize: if self.logger: self.logger.info("Minimising initial structure") diff --git a/janus_core/calculations/geom_opt.py b/janus_core/calculations/geom_opt.py index 6077fddd..f4af4eec 100644 --- a/janus_core/calculations/geom_opt.py +++ b/janus_core/calculations/geom_opt.py @@ -304,6 +304,8 @@ def run(self) -> None: if self.tracker: self.tracker.start_task("Geometry optimization") + self._set_info_units() + converged = self.dyn.run(fmax=self.fmax, steps=self.steps) # Calculate current maximum force diff --git a/janus_core/calculations/md.py b/janus_core/calculations/md.py index 4a1119e0..cf71924a 100644 --- a/janus_core/calculations/md.py +++ b/janus_core/calculations/md.py @@ -12,7 +12,7 @@ from typing import Any from warnings import warn -from ase import Atoms, units +from ase import Atoms from ase.geometry.analysis import Analysis from ase.io import read from ase.md.langevin import Langevin @@ -23,9 +23,11 @@ ZeroRotation, ) from ase.md.verlet import VelocityVerlet +from ase.units import create_units import numpy as np import yaml +from janus_core.calculations.base import UNITS as JANUS_UNITS from janus_core.calculations.base import BaseCalculation from janus_core.calculations.geom_opt import GeomOpt from janus_core.helpers.janus_types import ( @@ -43,6 +45,7 @@ from janus_core.processing.correlator import Correlation from janus_core.processing.post_process import compute_rdf, compute_vaf +units = create_units("2014") DENS_FACT = (units.m / 1.0e2) ** 3 / units.mol @@ -520,7 +523,7 @@ def _set_info(self) -> None: """Set time in fs, current dynamics step, and density to info.""" time = (self.offset * self.timestep + self.dyn.get_time()) / units.fs step = self.offset + self.dyn.nsteps - self.dyn.atoms.info["time_fs"] = time + self.dyn.atoms.info["time"] = time self.dyn.atoms.info["step"] = step try: density = ( @@ -769,7 +772,7 @@ def get_stats(self) -> dict[str, float]: return { "Step": self.dyn.atoms.info["step"], "Real_Time": real_time.total_seconds(), - "Time": self.dyn.atoms.info["time_fs"], + "Time": self.dyn.atoms.info["time"], "Epot/N": e_pot, "EKin/N": e_kin, "T": current_temp, @@ -797,21 +800,21 @@ def unit_info(self) -> dict[str, str]: """ return { "Step": None, - "Real_Time": "s", - "Time": "fs", - "Epot/N": "eV", - "EKin/N": "eV", - "T": "K", - "ETot/N": "eV", - "Density": "g/cm^3", - "Volume": "A^3", - "P": "GPa", - "Pxx": "GPa", - "Pyy": "GPa", - "Pzz": "GPa", - "Pyz": "GPa", - "Pxz": "GPa", - "Pxy": "GPa", + "Real_Time": JANUS_UNITS["real_time"], + "Time": JANUS_UNITS["time"], + "Epot/N": JANUS_UNITS["energy"], + "EKin/N": JANUS_UNITS["energy"], + "T": JANUS_UNITS["temperature"], + "ETot/N": JANUS_UNITS["energy"], + "Density": JANUS_UNITS["density"], + "Volume": JANUS_UNITS["volume"], + "P": JANUS_UNITS["pressure"], + "Pxx": JANUS_UNITS["pressure"], + "Pyy": JANUS_UNITS["pressure"], + "Pzz": JANUS_UNITS["pressure"], + "Pyz": JANUS_UNITS["pressure"], + "Pxz": JANUS_UNITS["pressure"], + "Pxy": JANUS_UNITS["pressure"], } @property @@ -1021,6 +1024,19 @@ def _write_restart(self) -> None: def run(self) -> None: """Run molecular dynamics simulation and/or temperature ramp.""" + unit_keys = ( + "energy", + "forces", + "stress", + "time", + "real_time", + "temperature", + "pressure", + "density", + "momenta", + ) + self._set_info_units(unit_keys) + if not self.restart: if self.minimize: self._optimize_structure() @@ -1262,7 +1278,10 @@ def unit_info(self) -> dict[str, str]: dict[str, str] Units attached to statistical properties. """ - return super().unit_info | {"Target_P": "GPa", "Target_T": "K"} + return super().unit_info | { + "Target_P": JANUS_UNITS["pressure"], + "Target_T": JANUS_UNITS["temperature"], + } @property def default_formats(self) -> dict[str, str]: @@ -1359,7 +1378,7 @@ def unit_info(self) -> dict[str, str]: dict[str, str] Units attached to statistical properties. """ - return super().unit_info | {"Target_T": "K"} + return super().unit_info | {"Target_T": JANUS_UNITS["temperature"]} @property def default_formats(self) -> dict[str, str]: @@ -1502,7 +1521,7 @@ def unit_info(self) -> dict[str, str]: dict[str, str] Units attached to statistical properties. """ - return super().unit_info | {"Target_T": "K"} + return super().unit_info | {"Target_T": JANUS_UNITS["temperature"]} @property def default_formats(self) -> dict[str, str]: diff --git a/janus_core/calculations/phonons.py b/janus_core/calculations/phonons.py index fd936fb6..7f3e6062 100644 --- a/janus_core/calculations/phonons.py +++ b/janus_core/calculations/phonons.py @@ -435,6 +435,8 @@ def calc_force_constants( if self.tracker: self.tracker.start_task("Phonon calculation") + self._set_info_units() + cell = self._ASE_to_PhonopyAtoms(self.struct) if len(self.supercell) == 3: diff --git a/janus_core/calculations/single_point.py b/janus_core/calculations/single_point.py index 953f0070..806bf622 100644 --- a/janus_core/calculations/single_point.py +++ b/janus_core/calculations/single_point.py @@ -321,6 +321,8 @@ def run(self) -> CalcResults: if self.tracker: self.tracker.start_task("Single point") + self._set_info_units(self.properties) + if "energy" in self.properties: self.results["energy"] = self._get_potential_energy() if "forces" in self.properties: diff --git a/janus_core/processing/observables.py b/janus_core/processing/observables.py index 52fe0b91..38a946f9 100644 --- a/janus_core/processing/observables.py +++ b/janus_core/processing/observables.py @@ -5,13 +5,16 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING -from ase import Atoms, units +from ase import Atoms +from ase.units import create_units if TYPE_CHECKING: from janus_core.helpers.janus_types import SliceLike from janus_core.helpers.utils import slicelike_to_startstopstep +units = create_units("2014") + # pylint: disable=too-few-public-methods class Observable(ABC): diff --git a/tests/test_geomopt_cli.py b/tests/test_geomopt_cli.py index dd2769f6..69d9af27 100644 --- a/tests/test_geomopt_cli.py +++ b/tests/test_geomopt_cli.py @@ -740,3 +740,32 @@ def test_no_carbon(tmp_path): with open(summary_path, encoding="utf8") as file: geomopt_summary = yaml.safe_load(file) assert "emissions" not in geomopt_summary + + +def test_units(tmp_path): + """Test correct units are saved.""" + results_path = tmp_path / "NaCl-opt.extxyz" + log_path = tmp_path / "test.log" + summary_path = tmp_path / "summary.yml" + + result = runner.invoke( + app, + [ + "geomopt", + "--struct", + DATA_PATH / "NaCl.cif", + "--out", + results_path, + "--log", + log_path, + "--summary", + summary_path, + ], + ) + assert result.exit_code == 0 + + atoms = read(results_path) + expected_units = {"energy": "eV", "forces": "ev/Ang", "stress": "ev/Ang^3"} + assert "units" in atoms.info + for prop, units in expected_units.items(): + assert atoms.info["units"][prop] == units diff --git a/tests/test_md_cli.py b/tests/test_md_cli.py index 1423c6b5..1161ad46 100644 --- a/tests/test_md_cli.py +++ b/tests/test_md_cli.py @@ -108,6 +108,23 @@ def test_md(ensemble): assert "momenta" in atoms.arrays assert "masses" in atoms.arrays + expected_units = { + "time": "fs", + "real_time": "s", + "energy": "eV", + "forces": "ev/Ang", + "stress": "ev/Ang^3", + "temperature": "K", + "density": "g/cm^3", + "momenta": "(eV*u)^0.5", + } + if ensemble in ("nvt", "nvt-nh"): + expected_units["pressure"] = "GPa" + + assert "units" in atoms.info + for prop, units in expected_units.items(): + assert atoms.info["units"][prop] == units + finally: final_path.unlink(missing_ok=True) restart_path.unlink(missing_ok=True) @@ -153,7 +170,7 @@ def test_log(tmp_path): assert len(lines) == 22 # Test constant volume - assert lines[0].split(" | ")[8] == "Volume [A^3]" + assert lines[0].split(" | ")[8] == "Volume [Ang^3]" init_volume = float(lines[1].split()[8]) final_volume = float(lines[-1].split()[8]) assert init_volume == 179.406144 diff --git a/tests/test_singlepoint_cli.py b/tests/test_singlepoint_cli.py index 2476e5a7..099fe6b7 100644 --- a/tests/test_singlepoint_cli.py +++ b/tests/test_singlepoint_cli.py @@ -72,6 +72,11 @@ def test_singlepoint(): assert "system_name" in atoms.info assert atoms.info["system_name"] == "NaCl" + expected_units = {"energy": "eV", "forces": "ev/Ang", "stress": "ev/Ang^3"} + assert "units" in atoms.info + for prop, units in expected_units.items(): + assert atoms.info["units"][prop] == units + clear_log_handlers() @@ -399,6 +404,7 @@ def test_hessian(tmp_path): assert "mace_mp_hessian" in atoms.info assert "mace_stress" not in atoms.info assert atoms.info["mace_mp_hessian"].shape == (24, 8, 3) + assert atoms.info["units"]["hessian"] == "ev/Ang^2" def test_no_carbon(tmp_path):