Skip to content

Commit

Permalink
Add post-processing routines to MD
Browse files Browse the repository at this point in the history
  • Loading branch information
oerc0122 committed May 31, 2024
1 parent 2dcb1c8 commit 5a94927
Show file tree
Hide file tree
Showing 7 changed files with 352 additions and 8 deletions.
10 changes: 10 additions & 0 deletions docs/source/apidoc/janus_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,16 @@ janus\_core.helpers.descriptors module
:undoc-members:
:show-inheritance:

janus\_core.helpers.post_process module
---------------------------------------

.. automodule:: janus_core.helpers.post_process
:members:
:special-members:
:private-members:
:undoc-members:
:show-inheritance:

janus\_core.helpers.train module
--------------------------------

Expand Down
81 changes: 78 additions & 3 deletions janus_core/calculations/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from warnings import warn

from ase import Atoms, units
from ase.io import write
from ase.io import read, write
from ase.md.langevin import Langevin
from ase.md.npt import NPT as ASE_NPT
from ase.md.velocitydistribution import (
Expand All @@ -18,11 +18,21 @@
ZeroRotation,
)
from ase.md.verlet import VelocityVerlet

try:
from ase.geometry.analysis import Analysis

ASE_GEOMETRY = True
except ImportError:

ASE_GEOMETRY = False

import numpy as np

from janus_core.calculations.geom_opt import optimize
from janus_core.helpers.janus_types import Ensembles, PathLike
from janus_core.helpers.janus_types import Ensembles, PathLike, PostProcessKwargs
from janus_core.helpers.log import config_logger
from janus_core.helpers.post_process import compute_rdf, compute_vaf
from janus_core.helpers.utils import FileNameMixin

DENS_FACT = (units.m / 1.0e2) ** 3 / units.mol
Expand Down Expand Up @@ -97,6 +107,8 @@ class MolecularDynamics(FileNameMixin): # pylint: disable=too-many-instance-att
heating.
temp_time : Optional[float]
Time between heating steps, in fs. Default is None, which disables heating.
post_process_kwargs : Optional[dict[str,Any]]
Keyword arguments to control post-processing operations.
log_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to log config. Default is None.
seed : Optional[int]
Expand Down Expand Up @@ -157,6 +169,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta
temp_end: Optional[float] = None,
temp_step: Optional[float] = None,
temp_time: Optional[float] = None,
post_process_kwargs: Optional[PostProcessKwargs] = None,
log_kwargs: Optional[dict[str, Any]] = None,
seed: Optional[int] = None,
) -> None:
Expand Down Expand Up @@ -231,6 +244,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta
disables heating.
temp_time : Optional[float]
Time between heating steps, in fs. Default is None, which disables heating.
post_process_kwargs : Optional[PostProcessKwargs]
Keyword arguments to control post-processing operations.
log_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to log config. Default is None.
seed : Optional[int]
Expand Down Expand Up @@ -262,6 +277,9 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta
self.temp_end = temp_end
self.temp_step = temp_step
self.temp_time = temp_time * units.fs if temp_time else None
self.post_process_kwargs = (
post_process_kwargs if post_process_kwargs is not None else {}
)
self.log_kwargs = log_kwargs
self.ensemble = ensemble
self.seed = seed
Expand Down Expand Up @@ -315,7 +333,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta

self.minimize_kwargs = minimize_kwargs if minimize_kwargs else {}
self.restart_files = []
self.dyn = None
self.dyn: Langevin | VelocityVerlet | ASE_NPT
self.n_atoms = len(self.struct)

self.stats_file = self._build_filename(
Expand Down Expand Up @@ -542,6 +560,60 @@ def _write_final_state(self) -> None:
columns=["symbols", "positions", "momenta", "masses"],
)

def _post_process(self) -> None:
"""Compute properties after MD run."""
# Nothing to do
if not any(
self.post_process_kwargs.get(kwarg, None)
for kwarg in ("rdf_compute", "vaf_compute")
):
return

data = read(self.traj_file)

if ASE_GEOMETRY:
ana = Analysis(data)
else:
ana = None

param_pref = self._parameter_prefix if self.file_prefix is None else ""

if self.post_process_kwargs.get("rdf_compute", False):
base_name = self.post_process_kwargs.get("rdf_output_file", None)
rdf_args = {
name: self.post_process_kwargs.get(key, default)
for name, (key, default) in (
("rmax", ("rdf_rmax", 2.5)),
("nbins", ("rdf_nbins", 50)),
("elements", ("rdf_elements", None)),
)
}
slice_ = (
self.post_process_kwargs.get("rdf_start", 0),
self.post_process_kwargs.get("rdf_stop", 1),
self.post_process_kwargs.get("rdf_step", 1),
)

out_paths = [
self._build_filename(
"rdf.dat", param_pref, str(ind), prefix_override=base_name
)
for ind in range(*slice_)
]

rdf_args["index"] = slice_
compute_rdf(data, out_paths, ana, **rdf_args)

if self.post_process_kwargs.get("vaf_compute", False):

file_name = self.post_process_kwargs.get("vaf_output_file", None)
use_vel = self.post_process_kwargs.get("vaf_velocities", False)
fft = self.post_process_kwargs.get("vaf_fft", False)

out_path = self._build_filename("vaf.dat", param_pref, filename=file_name)

compute_vaf(data, out_path, use_velocities=use_vel, fft=fft)

def _write_restart(self) -> None:
"""Write restart file and (optionally) rotate files saved."""
step = self.offset + self.dyn.nsteps
Expand Down Expand Up @@ -594,6 +666,9 @@ def run(self) -> None:
self.struct.info["real_time"] = datetime.datetime.now()
self._run_dynamics()

if self.post_process_kwargs:
self._post_process()

def _run_dynamics(self) -> None:
"""Run dynamics and/or temperature ramp."""
# Store temperature for final MD
Expand Down
11 changes: 9 additions & 2 deletions janus_core/cli/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Device,
LogPath,
MinimizeKwargs,
PostProcessKwargs,
ReadKwargs,
StructPath,
Summary,
Expand Down Expand Up @@ -168,6 +169,7 @@ def md(
temp_time: Annotated[
float, Option(help="Time between heating steps, in fs.")
] = None,
post_process_kwargs: PostProcessKwargs = None,
log: LogPath = "md.log",
seed: Annotated[
Optional[int],
Expand Down Expand Up @@ -267,6 +269,8 @@ def md(
temp_time : Optional[float]
Time between heating steps, in fs. Default is None, which disables
heating.
post_process_kwargs : Optional[dict[str, Any]]
Kwargs to pass to post-processing.
log : Optional[Path]
Path to write logs to. Default is "md.log".
seed : Optional[int]
Expand All @@ -280,8 +284,10 @@ def md(
# Check options from configuration file are all valid
check_config(ctx)

[read_kwargs, calc_kwargs, minimize_kwargs] = parse_typer_dicts(
[read_kwargs, calc_kwargs, minimize_kwargs]
[read_kwargs, calc_kwargs, minimize_kwargs, post_process_kwargs] = (
parse_typer_dicts(
[read_kwargs, calc_kwargs, minimize_kwargs, post_process_kwargs]
)
)

if not ensemble in get_args(Ensembles):
Expand Down Expand Up @@ -334,6 +340,7 @@ def md(
"temp_end": temp_end,
"temp_step": temp_step,
"temp_time": temp_time,
"post_process_kwargs": post_process_kwargs,
"log_kwargs": log_kwargs,
"seed": seed,
}
Expand Down
15 changes: 15 additions & 0 deletions janus_core/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,21 @@ def __str__(self):
),
]

PostProcessKwargs = Annotated[
TyperDict,
Option(
parser=parse_dict_class,
help=(
"""
Keyword arguments to pass to post-processer. Must be passed as a dictionary
wrapped in quotes, e.g. "{'key' : value}".
"""
),
metavar="DICT",
),
]


LogPath = Annotated[Path, Option(help="Path to save logs to.")]

Summary = Annotated[
Expand Down
21 changes: 20 additions & 1 deletion janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,32 @@ class ASEWriteArgs(TypedDict, total=False):


class ASEOptArgs(TypedDict, total=False):
"""Main arugments for ase optimisers."""
"""Main arguments for ase optimisers."""

restart: Optional[bool]
logfile: Optional[PathLike]
trajectory: Optional[str]


class PostProcessKwargs(TypedDict, total=False):
"""Main arguments for MD post-processing."""

# RDF
rdf_compute: bool
rdf_rmax: float
rdf_nbins: int
rdf_elements: MaybeSequence[str | int]
rdf_start: int
rdf_stop: Optional[int]
rdf_step: int
rdf_output_file: Optional[str]
# VAF
vaf_compute: bool
vaf_velocities: bool
vaf_fft: bool
vaf_output_file: Optional[PathLike]


# eos_names from ase.eos
EoSNames = Literal[
"sj",
Expand Down
Loading

0 comments on commit 5a94927

Please sign in to comment.