Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add units #378

Merged
merged 11 commits into from
Jan 14, 2025
32 changes: 32 additions & 0 deletions janus_core/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@
from janus_core.helpers.struct_io import input_structs
from janus_core.helpers.utils import FileNameMixin, none_to_dict

UNITS = {
"energy": "eV",
"forces": "ev/Ang",
"stress": "ev/Ang^3",
"hessian": "ev/Ang^2",
"time": "fs",
"real_time": "s",
"temperature": "K",
"pressure": "GPa",
"momenta": "(eV*u)^0.5",
"density": "g/cm^3",
"volume": "Ang^3",
}


class BaseCalculation(FileNameMixin):
"""
Expand Down Expand Up @@ -202,3 +216,21 @@ def __init__(
self.tracker = config_tracker(
self.logger, self.track_carbon, **self.tracker_kwargs
)

def _set_info_units(
self, keys: Sequence[str] = ("energy", "forces", "stress")
) -> None:
"""
Save units to structure info.

Parameters
----------
keys : Sequence
Keys for which to add units to structure info. Default is
("energy", "forces", "stress").
"""
if isinstance(self.struct, Sequence):
for image in self.struct:
image.info["units"] = {key: UNITS[key] for key in keys}
else:
self.struct.info["units"] = {key: UNITS[key] for key in keys}
2 changes: 2 additions & 0 deletions janus_core/calculations/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ def run(self) -> EoSResults:
Dictionary containing equation of state ASE object, and the fitted minimum
bulk modulus, volume, and energy.
"""
self._set_info_units()

if self.minimize:
if self.logger:
self.logger.info("Minimising initial structure")
Expand Down
2 changes: 2 additions & 0 deletions janus_core/calculations/geom_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ def run(self) -> None:
if self.tracker:
self.tracker.start_task("Geometry optimization")

self._set_info_units()

converged = self.dyn.run(fmax=self.fmax, steps=self.steps)

# Calculate current maximum force
Expand Down
61 changes: 40 additions & 21 deletions janus_core/calculations/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Any
from warnings import warn

from ase import Atoms, units
from ase import Atoms
from ase.geometry.analysis import Analysis
from ase.io import read
from ase.md.langevin import Langevin
Expand All @@ -23,9 +23,11 @@
ZeroRotation,
)
from ase.md.verlet import VelocityVerlet
from ase.units import create_units
import numpy as np
import yaml

from janus_core.calculations.base import UNITS as JANUS_UNITS
from janus_core.calculations.base import BaseCalculation
from janus_core.calculations.geom_opt import GeomOpt
from janus_core.helpers.janus_types import (
Expand All @@ -43,6 +45,7 @@
from janus_core.processing.correlator import Correlation
from janus_core.processing.post_process import compute_rdf, compute_vaf

units = create_units("2014")
ElliottKasoar marked this conversation as resolved.
Show resolved Hide resolved
DENS_FACT = (units.m / 1.0e2) ** 3 / units.mol


Expand Down Expand Up @@ -520,7 +523,7 @@ def _set_info(self) -> None:
"""Set time in fs, current dynamics step, and density to info."""
time = (self.offset * self.timestep + self.dyn.get_time()) / units.fs
step = self.offset + self.dyn.nsteps
self.dyn.atoms.info["time_fs"] = time
self.dyn.atoms.info["time"] = time
self.dyn.atoms.info["step"] = step
try:
density = (
Expand Down Expand Up @@ -769,7 +772,7 @@ def get_stats(self) -> dict[str, float]:
return {
"Step": self.dyn.atoms.info["step"],
"Real_Time": real_time.total_seconds(),
"Time": self.dyn.atoms.info["time_fs"],
"Time": self.dyn.atoms.info["time"],
"Epot/N": e_pot,
"EKin/N": e_kin,
"T": current_temp,
Expand Down Expand Up @@ -797,21 +800,21 @@ def unit_info(self) -> dict[str, str]:
"""
return {
"Step": None,
"Real_Time": "s",
"Time": "fs",
"Epot/N": "eV",
"EKin/N": "eV",
"T": "K",
"ETot/N": "eV",
"Density": "g/cm^3",
"Volume": "A^3",
"P": "GPa",
"Pxx": "GPa",
"Pyy": "GPa",
"Pzz": "GPa",
"Pyz": "GPa",
"Pxz": "GPa",
"Pxy": "GPa",
"Real_Time": JANUS_UNITS["real_time"],
"Time": JANUS_UNITS["time"],
"Epot/N": JANUS_UNITS["energy"],
"EKin/N": JANUS_UNITS["energy"],
"T": JANUS_UNITS["temperature"],
"ETot/N": JANUS_UNITS["energy"],
"Density": JANUS_UNITS["density"],
"Volume": JANUS_UNITS["volume"],
"P": JANUS_UNITS["pressure"],
"Pxx": JANUS_UNITS["pressure"],
"Pyy": JANUS_UNITS["pressure"],
"Pzz": JANUS_UNITS["pressure"],
"Pyz": JANUS_UNITS["pressure"],
"Pxz": JANUS_UNITS["pressure"],
"Pxy": JANUS_UNITS["pressure"],
}

@property
Expand Down Expand Up @@ -1021,6 +1024,19 @@ def _write_restart(self) -> None:

def run(self) -> None:
"""Run molecular dynamics simulation and/or temperature ramp."""
unit_keys = (
"energy",
"forces",
"stress",
"time",
"real_time",
"temperature",
"pressure",
"density",
"momenta",
)
self._set_info_units(unit_keys)

if not self.restart:
if self.minimize:
self._optimize_structure()
Expand Down Expand Up @@ -1262,7 +1278,10 @@ def unit_info(self) -> dict[str, str]:
dict[str, str]
Units attached to statistical properties.
"""
return super().unit_info | {"Target_P": "GPa", "Target_T": "K"}
return super().unit_info | {
"Target_P": JANUS_UNITS["pressure"],
"Target_T": JANUS_UNITS["temperature"],
}

@property
def default_formats(self) -> dict[str, str]:
Expand Down Expand Up @@ -1359,7 +1378,7 @@ def unit_info(self) -> dict[str, str]:
dict[str, str]
Units attached to statistical properties.
"""
return super().unit_info | {"Target_T": "K"}
return super().unit_info | {"Target_T": JANUS_UNITS["temperature"]}

@property
def default_formats(self) -> dict[str, str]:
Expand Down Expand Up @@ -1502,7 +1521,7 @@ def unit_info(self) -> dict[str, str]:
dict[str, str]
Units attached to statistical properties.
"""
return super().unit_info | {"Target_T": "K"}
return super().unit_info | {"Target_T": JANUS_UNITS["temperature"]}

@property
def default_formats(self) -> dict[str, str]:
Expand Down
2 changes: 2 additions & 0 deletions janus_core/calculations/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ def calc_force_constants(
if self.tracker:
self.tracker.start_task("Phonon calculation")

self._set_info_units()

cell = self._ASE_to_PhonopyAtoms(self.struct)

if len(self.supercell) == 3:
Expand Down
2 changes: 2 additions & 0 deletions janus_core/calculations/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ def run(self) -> CalcResults:
if self.tracker:
self.tracker.start_task("Single point")

self._set_info_units(self.properties)

if "energy" in self.properties:
self.results["energy"] = self._get_potential_energy()
if "forces" in self.properties:
Expand Down
5 changes: 4 additions & 1 deletion janus_core/processing/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from ase import Atoms, units
from ase import Atoms
from ase.units import create_units

if TYPE_CHECKING:
from janus_core.helpers.janus_types import SliceLike

from janus_core.helpers.utils import slicelike_to_startstopstep

units = create_units("2014")


# pylint: disable=too-few-public-methods
class Observable(ABC):
Expand Down
29 changes: 29 additions & 0 deletions tests/test_geomopt_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,3 +740,32 @@ def test_no_carbon(tmp_path):
with open(summary_path, encoding="utf8") as file:
geomopt_summary = yaml.safe_load(file)
assert "emissions" not in geomopt_summary


def test_units(tmp_path):
"""Test correct units are saved."""
results_path = tmp_path / "NaCl-opt.extxyz"
log_path = tmp_path / "test.log"
summary_path = tmp_path / "summary.yml"

result = runner.invoke(
app,
[
"geomopt",
"--struct",
DATA_PATH / "NaCl.cif",
"--out",
results_path,
"--log",
log_path,
"--summary",
summary_path,
],
)
assert result.exit_code == 0

atoms = read(results_path)
expected_units = {"energy": "eV", "forces": "ev/Ang", "stress": "ev/Ang^3"}
assert "units" in atoms.info
for prop, units in expected_units.items():
assert atoms.info["units"][prop] == units
19 changes: 18 additions & 1 deletion tests/test_md_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,23 @@ def test_md(ensemble):
assert "momenta" in atoms.arrays
assert "masses" in atoms.arrays

expected_units = {
"time": "fs",
"real_time": "s",
"energy": "eV",
"forces": "ev/Ang",
"stress": "ev/Ang^3",
"temperature": "K",
"density": "g/cm^3",
"momenta": "(eV*u)^0.5",
}
if ensemble in ("nvt", "nvt-nh"):
expected_units["pressure"] = "GPa"

assert "units" in atoms.info
for prop, units in expected_units.items():
assert atoms.info["units"][prop] == units

finally:
final_path.unlink(missing_ok=True)
restart_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -153,7 +170,7 @@ def test_log(tmp_path):
assert len(lines) == 22

# Test constant volume
assert lines[0].split(" | ")[8] == "Volume [A^3]"
assert lines[0].split(" | ")[8] == "Volume [Ang^3]"
init_volume = float(lines[1].split()[8])
final_volume = float(lines[-1].split()[8])
assert init_volume == 179.406144
Expand Down
6 changes: 6 additions & 0 deletions tests/test_singlepoint_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def test_singlepoint():
assert "system_name" in atoms.info
assert atoms.info["system_name"] == "NaCl"

expected_units = {"energy": "eV", "forces": "ev/Ang", "stress": "ev/Ang^3"}
assert "units" in atoms.info
for prop, units in expected_units.items():
assert atoms.info["units"][prop] == units

clear_log_handlers()


Expand Down Expand Up @@ -399,6 +404,7 @@ def test_hessian(tmp_path):
assert "mace_mp_hessian" in atoms.info
assert "mace_stress" not in atoms.info
assert atoms.info["mace_mp_hessian"].shape == (24, 8, 3)
assert atoms.info["units"]["hessian"] == "ev/Ang^2"


def test_no_carbon(tmp_path):
Expand Down
Loading