From 2dcb1c80af385b11fe5d6c13ea8b29f018573ed3 Mon Sep 17 00:00:00 2001 From: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com> Date: Thu, 30 May 2024 21:52:47 +0100 Subject: [PATCH] Adjust file handling to be a mixin (#168) * Adjust file handling --------- Co-authored-by: Alin Marin Elena Co-authored-by: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> --- janus_core/calculations/md.py | 75 ++++++------- janus_core/calculations/phonons.py | 56 +++------- janus_core/calculations/single_point.py | 20 ++-- janus_core/helpers/utils.py | 136 ++++++++++++++++++++++++ tests/test_filenamemixin.py | 111 +++++++++++++++++++ tests/test_geom_opt.py | 4 +- 6 files changed, 307 insertions(+), 95 deletions(-) create mode 100644 tests/test_filenamemixin.py diff --git a/janus_core/calculations/md.py b/janus_core/calculations/md.py index fd8d07ec..21fdd857 100644 --- a/janus_core/calculations/md.py +++ b/janus_core/calculations/md.py @@ -23,11 +23,12 @@ from janus_core.calculations.geom_opt import optimize from janus_core.helpers.janus_types import Ensembles, PathLike from janus_core.helpers.log import config_logger +from janus_core.helpers.utils import FileNameMixin DENS_FACT = (units.m / 1.0e2) ** 3 / units.mol -class MolecularDynamics: # pylint: disable=too-many-instance-attributes +class MolecularDynamics(FileNameMixin): # pylint: disable=too-many-instance-attributes """ Configure shared molecular dynamics simulation options. @@ -237,7 +238,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta Default is None. """ self.struct = struct - self.struct_name = struct_name self.timestep = timestep * units.fs self.steps = steps self.temp = temp @@ -247,7 +247,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta self.rescale_velocities = rescale_velocities self.remove_rot = remove_rot self.rescale_every = rescale_every - self.file_prefix = file_prefix self.restart = restart self.restart_stem = restart_stem self.restart_every = restart_every @@ -267,6 +266,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta self.ensemble = ensemble self.seed = seed + FileNameMixin.__init__(self, struct, struct_name, file_prefix, ensemble) + self.log_kwargs = ( log_kwargs if log_kwargs else {} ) # pylint: disable=duplicate-code @@ -317,11 +318,16 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta self.dyn = None self.n_atoms = len(self.struct) - # Infer names for structure and if not specified - if not self.struct_name: - self.struct_name = self.struct.get_chemical_formula() - - self.configure_filenames() + self.stats_file = self._build_filename( + "stats.dat", + self._parameter_prefix if file_prefix is None else "", + filename=self.stats_file, + ) + self.traj_file = self._build_filename( + "traj.xyz", + self._parameter_prefix if file_prefix is None else "", + filename=self.traj_file, + ) self.offset = 0 @@ -388,12 +394,12 @@ def _parameter_prefix(self) -> str: temperature_prefix = "" if self.temp_start is not None and self.temp_end is not None: - temperature_prefix = f"-T{self.temp_start}-T{self.temp_end}" + temperature_prefix += f"-T{self.temp_start}-T{self.temp_end}" if self.steps > 0: temperature_prefix += f"-T{self.temp}" - return temperature_prefix + return temperature_prefix.lstrip("-") @property def _final_file(self) -> str: @@ -406,10 +412,9 @@ def _final_file(self) -> str: File name for final state. """ - if not self.restart_stem: - return f"{self.file_prefix}-T{self.temp}-final.xyz" - # respect the users choice - return f"{self.restart_stem}-T{self.temp}-final.xyz" + return self._build_filename( + "final.xyz", f"T{self.temp}", prefix_override=self.restart_stem + ) @property def _restart_file(self) -> str: @@ -422,26 +427,9 @@ def _restart_file(self) -> str: File name for restart files. """ step = self.offset + self.dyn.nsteps - if not self.restart_stem: - return f"{self.file_prefix}-T{self.temp}-res-{step}.xyz" - return f"{self.restart_stem}-T{self.temp}-res-{step}.xyz" - - def configure_filenames(self) -> None: - """Setup filenames for output files.""" - - if not self.file_prefix: - self.file_prefix = f"{self.struct_name}-{self.ensemble}" - data_prefix = f"{self.file_prefix}{self._parameter_prefix}" - else: - data_prefix = f"{self.file_prefix}" - if not self.restart_stem: - self.restart_stem = f"{self.file_prefix}" - - if not self.stats_file: - self.stats_file = f"{data_prefix}-stats.dat" - - if not self.traj_file: - self.traj_file = f"{data_prefix}-traj.xyz" + return self._build_filename( + f"res-{step}.xyz", f"T{self.temp}", prefix_override=self.restart_stem + ) @staticmethod def get_log_header() -> str: @@ -773,10 +761,10 @@ def _final_file(self) -> str: File name for final state, includes pressure. """ - pressure = f"-p{self.pressure}" if not isinstance(self, NVT_NH) else "" - if not self.restart_stem: - return f"{self.file_prefix}-T{self.temp}{pressure}-final.xyz" - return f"{self.restart_stem}-T{self.temp}{pressure}-final.xyz" + pressure = f"p{self.pressure}" if not isinstance(self, NVT_NH) else "" + return self._build_filename( + "final.xyz", f"T{self.temp}", pressure, prefix_override=self.restart_stem + ) @property def _restart_file(self) -> str: @@ -789,10 +777,13 @@ def _restart_file(self) -> str: File name for restart file, includes pressure. """ step = self.offset + self.dyn.nsteps - pressure = f"-p{self.pressure}" if not isinstance(self, NVT_NH) else "" - if not self.restart_stem: - return f"{self.file_prefix}-T{self.temp}{pressure}-res-{step}.xyz" - return f"{self.restart_stem}-T{self.temp}{pressure}-res-{step}.xyz" + pressure = f"p{self.pressure}" if not isinstance(self, NVT_NH) else "" + return self._build_filename( + f"res-{step}.xyz", + f"T{self.temp}", + pressure, + prefix_override=self.restart_stem, + ) def get_log_stats(self) -> str: """ diff --git a/janus_core/calculations/phonons.py b/janus_core/calculations/phonons.py index 5ff95537..744a8591 100644 --- a/janus_core/calculations/phonons.py +++ b/janus_core/calculations/phonons.py @@ -11,10 +11,10 @@ from janus_core.calculations.geom_opt import optimize from janus_core.helpers.janus_types import MaybeList, PathLike from janus_core.helpers.log import config_logger -from janus_core.helpers.utils import none_to_dict +from janus_core.helpers.utils import FileNameMixin, none_to_dict -class Phonons: # pylint: disable=too-many-instance-attributes +class Phonons(FileNameMixin): # pylint: disable=too-many-instance-attributes """ Configure, perform phonon calculations and write out results. @@ -127,15 +127,11 @@ def __init__( # pylint: disable=too-many-arguments,disable=too-many-locals log_kwargs : Optional[dict[str, Any]] Keyword arguments to pass to `config_logger`. Default is {}. """ + FileNameMixin.__init__(self, struct, struct_name, file_prefix) + [minimize_kwargs, log_kwargs] = none_to_dict([minimize_kwargs, log_kwargs]) self.struct = struct - if struct_name: - self.struct_name = struct_name - else: - self.struct_name = self.struct.get_chemical_formula() - - self.file_prefix = file_prefix if file_prefix else self.struct_name # Ensure supercell is a valid list self.supercell = [supercell] * 3 if isinstance(supercell, int) else supercell @@ -163,28 +159,6 @@ def __init__( # pylint: disable=too-many-arguments,disable=too-many-locals self.calc = self.struct.calc self.results = {} - def _set_filename( - self, default_suffix: str, filename: Optional[PathLike] = None - ) -> str: - """ - Set filename using the file prefix and suffix if not specified otherwise. - - Parameters - ---------- - default_suffix : str - Default suffix to use if `filename` is not specified. - filename : Optional[PathLike] - Filename to use, if specified. Default is None. - - Returns - ------- - str - Filename specified, or default filename. - """ - if filename: - return filename - return f"{self.file_prefix}-{default_suffix}" - def calc_force_constants(self, write_results: bool = True) -> None: """ Calculate force constants and optionally write results. @@ -271,7 +245,7 @@ def write_band_structure( `file_prefix`. """ - bands_file = self._set_filename("auto_bands.yml", bands_file) + bands_file = self._build_filename("auto_bands.yml", filename=bands_file) self.results["phonon"].auto_band_structure( write_yaml=write_bands, filename=bands_file, @@ -280,7 +254,7 @@ def write_band_structure( ) if self.plot_to_file: bplt = self.results["phonon"].plot_band_structure() - plot_file = self._set_filename("auto_bands.svg", plot_file) + plot_file = self._build_filename("auto_bands.svg", filename=plot_file) bplt.savefig(plot_file) def write_force_constants( @@ -305,9 +279,9 @@ def write_force_constants( Name of hdf5 file to save force constants. Unused if `force_consts_to_hdf5` is False. Default is inferred from `file_prefix`. """ - phonopy_file = self._set_filename("phonopy.yml", phonopy_file) - force_consts_file = self._set_filename( - "force_constants.hdf5", force_consts_file + phonopy_file = self._build_filename("phonopy.yml", filename=phonopy_file) + force_consts_file = self._build_filename( + "force_constants.hdf5", filename=force_consts_file ) phonon = self.results["phonon"] @@ -360,7 +334,7 @@ def write_thermal_props(self, filename: Optional[PathLike] = None) -> None: Name of data file to save thermal properties. Default is inferred from `file_prefix`. """ - filename = self._set_filename("thermal.dat", filename) + filename = self._build_filename("thermal.dat", filename=filename) with open(filename, "w", encoding="utf8") as out: temps = self.results["thermal_properties"]["temperatures"] @@ -422,15 +396,15 @@ def write_dos( Name of svg file to plot the band structure and DOS. Default is inferred from `file_prefix`. """ - filename = self._set_filename("dos.dat", filename) + filename = self._build_filename("dos.dat", filename=filename) self.results["phonon"].total_dos.write(filename) if self.plot_to_file: bplt = self.results["phonon"].plot_total_dos() - plot_file = self._set_filename("dos.svg", plot_file) + plot_file = self._build_filename("dos.svg", filename=plot_file) bplt.savefig(plot_file) bplt = self.results["phonon"].plot_band_structure_and_dos() - plot_bs_file = self._set_filename("bs-dos.svg", plot_bs_file) + plot_bs_file = self._build_filename("bs-dos.svg", filename=plot_bs_file) bplt.savefig(plot_bs_file) def calc_pdos( @@ -479,11 +453,11 @@ def write_pdos( Name of svg file to plot the calculated PDOS. Default is inferred from `file_prefix`. """ - filename = self._set_filename("pdos.dat", filename) + filename = self._build_filename("pdos.dat", filename=filename) self.results["phonon"].projected_dos.write(filename) if self.plot_to_file: bplt = self.results["phonon"].plot_projected_dos() - plot_file = self._set_filename("pdos.svg", plot_file) + plot_file = self._build_filename("pdos.svg", filename=plot_file) bplt.savefig(plot_file) # No magnetic moments considered diff --git a/janus_core/calculations/single_point.py b/janus_core/calculations/single_point.py index f41f1ec2..e09b50c6 100644 --- a/janus_core/calculations/single_point.py +++ b/janus_core/calculations/single_point.py @@ -19,10 +19,10 @@ ) from janus_core.helpers.log import config_logger from janus_core.helpers.mlip_calculators import choose_calculator -from janus_core.helpers.utils import none_to_dict +from janus_core.helpers.utils import FileNameMixin, none_to_dict -class SinglePoint: +class SinglePoint(FileNameMixin): """ Prepare and perform single point calculations. @@ -148,8 +148,8 @@ def __init__( self.read_structure(**read_kwargs) else: self.struct = struct - if not self.struct_name: - self.struct_name = self.struct.get_chemical_formula() + + FileNameMixin.__init__(self, self.struct, self.struct_name, None) # Configure calculator self.set_calculator(**calc_kwargs) @@ -169,13 +169,13 @@ def read_structure(self, **kwargs) -> None: **kwargs Keyword arguments passed to ase.io.read. """ - if self.struct_path: - self.struct = read(self.struct_path, **kwargs) - if not self.struct_name: - self.struct_name = Path(self.struct_path).stem - else: + if not self.struct_path: raise ValueError("`struct_path` must be defined") + self.struct = read(self.struct_path, **kwargs) + if not self.struct_name: + self.struct_name = Path(self.struct_path).stem + def set_calculator( self, read_kwargs: Optional[ASEReadArgs] = None, **kwargs ) -> None: @@ -346,7 +346,7 @@ def run( write_kwargs.setdefault( "filename", - Path(f"./{self.struct_name}-results.xyz").absolute(), + self._build_filename("results.xyz").absolute(), ) if self.logger: diff --git a/janus_core/helpers/utils.py b/janus_core/helpers/utils.py index b5ce8b27..088aab64 100644 --- a/janus_core/helpers/utils.py +++ b/janus_core/helpers/utils.py @@ -1,11 +1,147 @@ """Utility functions for janus_core.""" +from abc import ABC from pathlib import Path from typing import Optional from ase import Atoms from spglib import get_spacegroup +from janus_core.helpers.janus_types import PathLike + + +class FileNameMixin(ABC): # pylint: disable=too-few-public-methods + """ + Provide mixin functions for standard filename handling. + + Parameters + ---------- + struct : Atoms + Structure from which to derive the default name if struct_name not provided. + struct_name : Optional[str] + Struct name to use. + file_prefix : Optional[PathLike] + Default prefix to use. + *additional + Components to add to file_prefix (joined by hyphens). + + Methods + ------- + _get_default_struct_name(struct, struct_name) + Return the name from the provided struct_name or generate from struct. + _get_default_prefix(file_prefix, struct_name) + Return a prefix from the provided file_prefix or from struct_name. + _build_filename(suffix, *additional, filename, prefix_override) + Return a standard format filename if filename not provided. + """ + + def __init__( + self, + struct: Atoms, + struct_name: Optional[str], + file_prefix: Optional[PathLike], + *additional, + ): + """ + Provide mixin functions for standard filename handling. + + Parameters + ---------- + struct : Atoms + Structure from which to derive the default name if struct_name not provided. + struct_name : Optional[str] + Struct name to use. + file_prefix : Optional[PathLike] + Default prefix to use. + *additional + Components to add to file_prefix (joined by hyphens). + """ + self.struct_name = self._get_default_struct_name(struct, struct_name) + + self.file_prefix = Path( + self._get_default_prefix(file_prefix, self.struct_name, *additional) + ) + + @staticmethod + def _get_default_struct_name(struct: Atoms, struct_name: Optional[str]) -> str: + """ + Determine the default struct name from the structure or provided struct_name. + + Parameters + ---------- + struct : Atoms + Structure of system. + struct_name : Optional[str] + Name of structure. + + Returns + ------- + str + Structure name. + """ + + if struct_name is not None: + return struct_name + return struct.get_chemical_formula() + + @staticmethod + def _get_default_prefix( + file_prefix: Optional[PathLike], struct_name: str, *additional + ) -> str: + """ + Determine the default prefix from the structure name or provided file_prefix. + + Parameters + ---------- + file_prefix : str + Given file_prefix. + struct_name : str + Name of structure. + *additional + Components to add to file_prefix (joined by hyphens). + + Returns + ------- + str + File prefix. + """ + if file_prefix is not None: + return str(file_prefix) + return "-".join((struct_name, *additional)) + + def _build_filename( + self, + suffix: str, + *additional, + filename: Optional[PathLike] = None, + prefix_override: Optional[str] = None, + ) -> Path: + """ + Set filename using the file prefix and suffix if not specified otherwise. + + Parameters + ---------- + suffix : str + Default suffix to use if `filename` is not specified. + *additional + Extra components to add to suffix (joined with hyphens). + filename : Optional[PathLike] + Filename to use, if specified. Default is None. + prefix_override : Optional[str] + Replace file_prefix if not None. + + Returns + ------- + Path + Filename specified, or default filename. + """ + if filename: + return Path(filename) + prefix = ( + prefix_override if prefix_override is not None else str(self.file_prefix) + ) + return Path("-".join((prefix, *filter(None, additional), suffix))) + def spacegroup( struct: Atoms, sym_tolerance: float = 0.001, angle_tolerance: float = -1.0 diff --git a/tests/test_filenamemixin.py b/tests/test_filenamemixin.py new file mode 100644 index 00000000..5c333216 --- /dev/null +++ b/tests/test_filenamemixin.py @@ -0,0 +1,111 @@ +"""Test FileNameMixin functions.""" + +from pathlib import Path + +from ase.io import read +import pytest + +from janus_core.helpers.utils import FileNameMixin + +DATA_PATH = Path(__file__).parent / "data" +STRUCT = read(DATA_PATH / "benzene.xyz") + + +class DummyFileHandler(FileNameMixin): # pylint: disable=too-few-public-methods + """Used for testing FileNameMixin methods.""" + + def build_filename(self, *args, **kwargs): + """ + Expose _build_filename publicly. + """ + return self._build_filename(*args, **kwargs) + + +@pytest.mark.parametrize( + "params,struct_name,file_prefix", + ( + # Defaults to structure atoms from ASE + ((STRUCT, None, None), "C6H6", "C6H6"), + # Passing structure name sets file_prefix + ((STRUCT, "benzene", None), "benzene", "benzene"), + # file_prefix just sets itself + ((STRUCT, None, "benzene"), "C6H6", "benzene"), + ((STRUCT, "benzene", "cake"), "benzene", "cake"), + # file_prefix ignores additional + ((STRUCT, "benzene", "benzene", "wowzers"), "benzene", "benzene"), + # Additional only applies where no file_prefix + ((STRUCT, "benzene", None, "wowzers"), "benzene", "benzene-wowzers"), + ((STRUCT, None, None, "wowzers"), "C6H6", "C6H6-wowzers"), + ), +) +def test_file_name_mixin_init(params, struct_name, file_prefix): + """Test various options for initializing the mixin.""" + file_mix = DummyFileHandler(*params) + + assert file_mix.struct_name == struct_name + assert file_mix.file_prefix == Path(file_prefix) + + +@pytest.mark.parametrize( + "mixin_params,file_args,file_kwargs,file_name", + ( + ((STRUCT, None, None), ("data.xyz",), {}, "C6H6-data.xyz"), + ((STRUCT, "benzene", None), ("data.xyz",), {}, "benzene-data.xyz"), + ((STRUCT, None, "benzene"), ("data.xyz",), {}, "benzene-data.xyz"), + ((STRUCT, "benzene", "benzene"), ("data.xyz",), {}, "benzene-data.xyz"), + ((STRUCT, "benzene", "benzene"), ("data.xyz",), {}, "benzene-data.xyz"), + ( + (STRUCT, "benzene", "benzene", "wowzers"), + ("data.xyz",), + {}, + "benzene-data.xyz", + ), + ((STRUCT, None, "benzene", "wowzers"), ("data.xyz",), {}, "benzene-data.xyz"), + ((STRUCT, None, None, "wowzers"), ("data.xyz",), {}, "C6H6-wowzers-data.xyz"), + # Additional stacks with base + ( + (STRUCT, None, None, "wowzers"), + ("data.xyz", "beef"), + {}, + "C6H6-wowzers-beef-data.xyz", + ), + # Prefix override ignores class options + ( + (STRUCT, None, None, "wowzers"), + ("data.xyz",), + {"prefix_override": "beef"}, + "beef-data.xyz", + ), + # But not additional + ( + (STRUCT, None, None, "wowzers"), + ("data.xyz", "tasty"), + {"prefix_override": "beef"}, + "beef-tasty-data.xyz", + ), + # Filename overrides everything + ( + (STRUCT, None, None, "wowzers"), + ("data.xyz",), + {"filename": "hello.xyz"}, + "hello.xyz", + ), + ( + (STRUCT, None, None, "wowzers"), + ("data.xyz",), + {"prefix_override": "beef", "filename": "hello.xyz"}, + "hello.xyz", + ), + ( + (STRUCT, None, None, "wowzers"), + ("data.xyz", "tasty"), + {"prefix_override": "beef", "filename": "hello.xyz"}, + "hello.xyz", + ), + ), +) +def test_file_name_mixin_build(mixin_params, file_args, file_kwargs, file_name): + """Test building the filename for mixins.""" + file_mix = DummyFileHandler(*mixin_params) + + assert file_mix.build_filename(*file_args, **file_kwargs) == Path(file_name) diff --git a/tests/test_geom_opt.py b/tests/test_geom_opt.py index 24225309..10f9f946 100644 --- a/tests/test_geom_opt.py +++ b/tests/test_geom_opt.py @@ -121,13 +121,13 @@ def test_missing_traj_kwarg(tmp_path): def test_hydrostatic_strain(): """Test setting hydrostatic strain for filter.""" single_point_1 = SinglePoint( - struct_path="./tests/data/NaCl-deformed.cif", + struct_path=DATA_PATH / "NaCl-deformed.cif", architecture="mace", calc_kwargs={"model": MODEL_PATH}, ) single_point_2 = SinglePoint( - struct_path="./tests/data/NaCl-deformed.cif", + struct_path=DATA_PATH / "NaCl-deformed.cif", architecture="mace", calc_kwargs={"model": MODEL_PATH}, )