Skip to content

Commit

Permalink
Adjust file handling to be a mixin (#168)
Browse files Browse the repository at this point in the history
* Adjust file handling

---------

Co-authored-by: Alin Marin Elena <[email protected]>
Co-authored-by: ElliottKasoar <[email protected]>
  • Loading branch information
3 people authored May 30, 2024
1 parent d084433 commit 2dcb1c8
Show file tree
Hide file tree
Showing 6 changed files with 307 additions and 95 deletions.
75 changes: 33 additions & 42 deletions janus_core/calculations/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
from janus_core.calculations.geom_opt import optimize
from janus_core.helpers.janus_types import Ensembles, PathLike
from janus_core.helpers.log import config_logger
from janus_core.helpers.utils import FileNameMixin

DENS_FACT = (units.m / 1.0e2) ** 3 / units.mol


class MolecularDynamics: # pylint: disable=too-many-instance-attributes
class MolecularDynamics(FileNameMixin): # pylint: disable=too-many-instance-attributes
"""
Configure shared molecular dynamics simulation options.
Expand Down Expand Up @@ -237,7 +238,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta
Default is None.
"""
self.struct = struct
self.struct_name = struct_name
self.timestep = timestep * units.fs
self.steps = steps
self.temp = temp
Expand All @@ -247,7 +247,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta
self.rescale_velocities = rescale_velocities
self.remove_rot = remove_rot
self.rescale_every = rescale_every
self.file_prefix = file_prefix
self.restart = restart
self.restart_stem = restart_stem
self.restart_every = restart_every
Expand All @@ -267,6 +266,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta
self.ensemble = ensemble
self.seed = seed

FileNameMixin.__init__(self, struct, struct_name, file_prefix, ensemble)

self.log_kwargs = (
log_kwargs if log_kwargs else {}
) # pylint: disable=duplicate-code
Expand Down Expand Up @@ -317,11 +318,16 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta
self.dyn = None
self.n_atoms = len(self.struct)

# Infer names for structure and if not specified
if not self.struct_name:
self.struct_name = self.struct.get_chemical_formula()

self.configure_filenames()
self.stats_file = self._build_filename(
"stats.dat",
self._parameter_prefix if file_prefix is None else "",
filename=self.stats_file,
)
self.traj_file = self._build_filename(
"traj.xyz",
self._parameter_prefix if file_prefix is None else "",
filename=self.traj_file,
)

self.offset = 0

Expand Down Expand Up @@ -388,12 +394,12 @@ def _parameter_prefix(self) -> str:

temperature_prefix = ""
if self.temp_start is not None and self.temp_end is not None:
temperature_prefix = f"-T{self.temp_start}-T{self.temp_end}"
temperature_prefix += f"-T{self.temp_start}-T{self.temp_end}"

if self.steps > 0:
temperature_prefix += f"-T{self.temp}"

return temperature_prefix
return temperature_prefix.lstrip("-")

@property
def _final_file(self) -> str:
Expand All @@ -406,10 +412,9 @@ def _final_file(self) -> str:
File name for final state.
"""

if not self.restart_stem:
return f"{self.file_prefix}-T{self.temp}-final.xyz"
# respect the users choice
return f"{self.restart_stem}-T{self.temp}-final.xyz"
return self._build_filename(
"final.xyz", f"T{self.temp}", prefix_override=self.restart_stem
)

@property
def _restart_file(self) -> str:
Expand All @@ -422,26 +427,9 @@ def _restart_file(self) -> str:
File name for restart files.
"""
step = self.offset + self.dyn.nsteps
if not self.restart_stem:
return f"{self.file_prefix}-T{self.temp}-res-{step}.xyz"
return f"{self.restart_stem}-T{self.temp}-res-{step}.xyz"

def configure_filenames(self) -> None:
"""Setup filenames for output files."""

if not self.file_prefix:
self.file_prefix = f"{self.struct_name}-{self.ensemble}"
data_prefix = f"{self.file_prefix}{self._parameter_prefix}"
else:
data_prefix = f"{self.file_prefix}"
if not self.restart_stem:
self.restart_stem = f"{self.file_prefix}"

if not self.stats_file:
self.stats_file = f"{data_prefix}-stats.dat"

if not self.traj_file:
self.traj_file = f"{data_prefix}-traj.xyz"
return self._build_filename(
f"res-{step}.xyz", f"T{self.temp}", prefix_override=self.restart_stem
)

@staticmethod
def get_log_header() -> str:
Expand Down Expand Up @@ -773,10 +761,10 @@ def _final_file(self) -> str:
File name for final state, includes pressure.
"""

pressure = f"-p{self.pressure}" if not isinstance(self, NVT_NH) else ""
if not self.restart_stem:
return f"{self.file_prefix}-T{self.temp}{pressure}-final.xyz"
return f"{self.restart_stem}-T{self.temp}{pressure}-final.xyz"
pressure = f"p{self.pressure}" if not isinstance(self, NVT_NH) else ""
return self._build_filename(
"final.xyz", f"T{self.temp}", pressure, prefix_override=self.restart_stem
)

@property
def _restart_file(self) -> str:
Expand All @@ -789,10 +777,13 @@ def _restart_file(self) -> str:
File name for restart file, includes pressure.
"""
step = self.offset + self.dyn.nsteps
pressure = f"-p{self.pressure}" if not isinstance(self, NVT_NH) else ""
if not self.restart_stem:
return f"{self.file_prefix}-T{self.temp}{pressure}-res-{step}.xyz"
return f"{self.restart_stem}-T{self.temp}{pressure}-res-{step}.xyz"
pressure = f"p{self.pressure}" if not isinstance(self, NVT_NH) else ""
return self._build_filename(
f"res-{step}.xyz",
f"T{self.temp}",
pressure,
prefix_override=self.restart_stem,
)

def get_log_stats(self) -> str:
"""
Expand Down
56 changes: 15 additions & 41 deletions janus_core/calculations/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from janus_core.calculations.geom_opt import optimize
from janus_core.helpers.janus_types import MaybeList, PathLike
from janus_core.helpers.log import config_logger
from janus_core.helpers.utils import none_to_dict
from janus_core.helpers.utils import FileNameMixin, none_to_dict


class Phonons: # pylint: disable=too-many-instance-attributes
class Phonons(FileNameMixin): # pylint: disable=too-many-instance-attributes
"""
Configure, perform phonon calculations and write out results.
Expand Down Expand Up @@ -127,15 +127,11 @@ def __init__( # pylint: disable=too-many-arguments,disable=too-many-locals
log_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to `config_logger`. Default is {}.
"""
FileNameMixin.__init__(self, struct, struct_name, file_prefix)

[minimize_kwargs, log_kwargs] = none_to_dict([minimize_kwargs, log_kwargs])

self.struct = struct
if struct_name:
self.struct_name = struct_name
else:
self.struct_name = self.struct.get_chemical_formula()

self.file_prefix = file_prefix if file_prefix else self.struct_name

# Ensure supercell is a valid list
self.supercell = [supercell] * 3 if isinstance(supercell, int) else supercell
Expand Down Expand Up @@ -163,28 +159,6 @@ def __init__( # pylint: disable=too-many-arguments,disable=too-many-locals
self.calc = self.struct.calc
self.results = {}

def _set_filename(
self, default_suffix: str, filename: Optional[PathLike] = None
) -> str:
"""
Set filename using the file prefix and suffix if not specified otherwise.
Parameters
----------
default_suffix : str
Default suffix to use if `filename` is not specified.
filename : Optional[PathLike]
Filename to use, if specified. Default is None.
Returns
-------
str
Filename specified, or default filename.
"""
if filename:
return filename
return f"{self.file_prefix}-{default_suffix}"

def calc_force_constants(self, write_results: bool = True) -> None:
"""
Calculate force constants and optionally write results.
Expand Down Expand Up @@ -271,7 +245,7 @@ def write_band_structure(
`file_prefix`.
"""

bands_file = self._set_filename("auto_bands.yml", bands_file)
bands_file = self._build_filename("auto_bands.yml", filename=bands_file)
self.results["phonon"].auto_band_structure(
write_yaml=write_bands,
filename=bands_file,
Expand All @@ -280,7 +254,7 @@ def write_band_structure(
)
if self.plot_to_file:
bplt = self.results["phonon"].plot_band_structure()
plot_file = self._set_filename("auto_bands.svg", plot_file)
plot_file = self._build_filename("auto_bands.svg", filename=plot_file)
bplt.savefig(plot_file)

def write_force_constants(
Expand All @@ -305,9 +279,9 @@ def write_force_constants(
Name of hdf5 file to save force constants. Unused if `force_consts_to_hdf5`
is False. Default is inferred from `file_prefix`.
"""
phonopy_file = self._set_filename("phonopy.yml", phonopy_file)
force_consts_file = self._set_filename(
"force_constants.hdf5", force_consts_file
phonopy_file = self._build_filename("phonopy.yml", filename=phonopy_file)
force_consts_file = self._build_filename(
"force_constants.hdf5", filename=force_consts_file
)

phonon = self.results["phonon"]
Expand Down Expand Up @@ -360,7 +334,7 @@ def write_thermal_props(self, filename: Optional[PathLike] = None) -> None:
Name of data file to save thermal properties. Default is inferred from
`file_prefix`.
"""
filename = self._set_filename("thermal.dat", filename)
filename = self._build_filename("thermal.dat", filename=filename)

with open(filename, "w", encoding="utf8") as out:
temps = self.results["thermal_properties"]["temperatures"]
Expand Down Expand Up @@ -422,15 +396,15 @@ def write_dos(
Name of svg file to plot the band structure and DOS.
Default is inferred from `file_prefix`.
"""
filename = self._set_filename("dos.dat", filename)
filename = self._build_filename("dos.dat", filename=filename)
self.results["phonon"].total_dos.write(filename)
if self.plot_to_file:
bplt = self.results["phonon"].plot_total_dos()
plot_file = self._set_filename("dos.svg", plot_file)
plot_file = self._build_filename("dos.svg", filename=plot_file)
bplt.savefig(plot_file)

bplt = self.results["phonon"].plot_band_structure_and_dos()
plot_bs_file = self._set_filename("bs-dos.svg", plot_bs_file)
plot_bs_file = self._build_filename("bs-dos.svg", filename=plot_bs_file)
bplt.savefig(plot_bs_file)

def calc_pdos(
Expand Down Expand Up @@ -479,11 +453,11 @@ def write_pdos(
Name of svg file to plot the calculated PDOS. Default is inferred from
`file_prefix`.
"""
filename = self._set_filename("pdos.dat", filename)
filename = self._build_filename("pdos.dat", filename=filename)
self.results["phonon"].projected_dos.write(filename)
if self.plot_to_file:
bplt = self.results["phonon"].plot_projected_dos()
plot_file = self._set_filename("pdos.svg", plot_file)
plot_file = self._build_filename("pdos.svg", filename=plot_file)
bplt.savefig(plot_file)

# No magnetic moments considered
Expand Down
20 changes: 10 additions & 10 deletions janus_core/calculations/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
)
from janus_core.helpers.log import config_logger
from janus_core.helpers.mlip_calculators import choose_calculator
from janus_core.helpers.utils import none_to_dict
from janus_core.helpers.utils import FileNameMixin, none_to_dict


class SinglePoint:
class SinglePoint(FileNameMixin):
"""
Prepare and perform single point calculations.
Expand Down Expand Up @@ -148,8 +148,8 @@ def __init__(
self.read_structure(**read_kwargs)
else:
self.struct = struct
if not self.struct_name:
self.struct_name = self.struct.get_chemical_formula()

FileNameMixin.__init__(self, self.struct, self.struct_name, None)

# Configure calculator
self.set_calculator(**calc_kwargs)
Expand All @@ -169,13 +169,13 @@ def read_structure(self, **kwargs) -> None:
**kwargs
Keyword arguments passed to ase.io.read.
"""
if self.struct_path:
self.struct = read(self.struct_path, **kwargs)
if not self.struct_name:
self.struct_name = Path(self.struct_path).stem
else:
if not self.struct_path:
raise ValueError("`struct_path` must be defined")

self.struct = read(self.struct_path, **kwargs)
if not self.struct_name:
self.struct_name = Path(self.struct_path).stem

def set_calculator(
self, read_kwargs: Optional[ASEReadArgs] = None, **kwargs
) -> None:
Expand Down Expand Up @@ -346,7 +346,7 @@ def run(

write_kwargs.setdefault(
"filename",
Path(f"./{self.struct_name}-results.xyz").absolute(),
self._build_filename("results.xyz").absolute(),
)

if self.logger:
Expand Down
Loading

0 comments on commit 2dcb1c8

Please sign in to comment.