From 493791972bab1b4e5e5185c8c166de4dcb0ffad3 Mon Sep 17 00:00:00 2001 From: Jan Janssen Date: Mon, 11 Dec 2023 21:22:12 +0100 Subject: [PATCH 1/7] Separate module for output classes --- atomistics/calculators/lammps/calculator.py | 20 ++++--- atomistics/calculators/lammps/helpers.py | 36 +++---------- atomistics/calculators/lammps/output.py | 59 +++++++++++++++++++++ 3 files changed, 75 insertions(+), 40 deletions(-) create mode 100644 atomistics/calculators/lammps/output.py diff --git a/atomistics/calculators/lammps/calculator.py b/atomistics/calculators/lammps/calculator.py index c6829a85..7fac412e 100644 --- a/atomistics/calculators/lammps/calculator.py +++ b/atomistics/calculators/lammps/calculator.py @@ -25,7 +25,7 @@ LAMMPS_RUN, LAMMPS_MINIMIZE_VOLUME, ) -from atomistics.calculators.lammps.helpers import quantities +from atomistics.calculators.lammps.output import get_static_output, quantities_md, quantities_static if TYPE_CHECKING: from ase import Atoms @@ -113,7 +113,7 @@ def calc_static_with_lammps( structure, potential_dataframe, lmp=None, - quantities=("energy", "forces", "stress"), + quantities=quantities_static, **kwargs, ): template_str = LAMMPS_THERMO_STYLE + "\n" + LAMMPS_THERMO + "\n" + LAMMPS_RUN @@ -127,12 +127,10 @@ def calc_static_with_lammps( lmp=lmp, **kwargs, ) - interactive_getter_dict = { - "forces": lmp_instance.interactive_forces_getter, - "energy": lmp_instance.interactive_energy_pot_getter, - "stress": lmp_instance.interactive_pressures_getter, - } - result_dict = {q: interactive_getter_dict[q]() for q in quantities} + result_dict = get_static_output( + lmp_instance=lmp_instance, + quantities=quantities, + ) lammps_shutdown(lmp_instance=lmp_instance, close_instance=lmp is None) return result_dict @@ -149,7 +147,7 @@ def calc_molecular_dynamics_nvt_with_lammps( seed=4928459, dist="gaussian", lmp=None, - quantities=quantities, + quantities=quantities_md, **kwargs, ): init_str = ( @@ -206,7 +204,7 @@ def calc_molecular_dynamics_npt_with_lammps( seed=4928459, dist="gaussian", lmp=None, - quantities=quantities, + quantities=quantities_md, **kwargs, ): init_str = ( @@ -264,7 +262,7 @@ def calc_molecular_dynamics_nph_with_lammps( seed=4928459, dist="gaussian", lmp=None, - quantities=quantities, + quantities=quantities_md, **kwargs, ): init_str = ( diff --git a/atomistics/calculators/lammps/helpers.py b/atomistics/calculators/lammps/helpers.py index e5a5bb25..790bd71e 100644 --- a/atomistics/calculators/lammps/helpers.py +++ b/atomistics/calculators/lammps/helpers.py @@ -1,35 +1,11 @@ from __future__ import annotations -import dataclasses - from jinja2 import Template import numpy as np from pylammpsmpi import LammpsASELibrary from atomistics.calculators.lammps.potential import validate_potential_dataframe - - -@dataclasses.dataclass -class LammpsQuantityGetter: - positions: callable = LammpsASELibrary.interactive_positions_getter - cell: callable = LammpsASELibrary.interactive_cells_getter - forces: callable = LammpsASELibrary.interactive_forces_getter - temperature: callable = LammpsASELibrary.interactive_temperatures_getter - energy_pot: callable = LammpsASELibrary.interactive_energy_pot_getter - energy_tot: callable = LammpsASELibrary.interactive_energy_tot_getter - pressure: callable = LammpsASELibrary.interactive_pressures_getter - velocities: callable = LammpsASELibrary.interactive_velocities_getter - - @classmethod - def fields(cls): - return tuple(field.name for field in dataclasses.fields(cls)) - - def __call__(self, engine: LammpsASELibrary, quantity: str): - return getattr(self, quantity)(engine) - - -quantity_getter = LammpsQuantityGetter() -quantities = quantity_getter.fields() +from atomistics.calculators.lammps.output import get_md_output, quantities_md def lammps_run(structure, potential_dataframe, input_template=None, lmp=None, **kwargs): @@ -65,12 +41,14 @@ def lammps_calc_md_step( lmp_instance, run_str, run, - quantities=quantities, + quantities=quantities_md, ): run_str_rendered = Template(run_str).render(run=run) lmp_instance.interactive_lib_command(run_str_rendered) - # return {q: getattr(LammpsQuantityGetter, q)(lmp_instance) for q in quantities} - return {q: quantity_getter(lmp_instance, q) for q in quantities} + return get_md_output( + lmp_instance=lmp_instance, + quantities=quantities, + ) def lammps_calc_md( @@ -78,7 +56,7 @@ def lammps_calc_md( run_str, run, thermo, - quantities=quantities, + quantities=quantities_md, ): results_lst = [ lammps_calc_md_step( diff --git a/atomistics/calculators/lammps/output.py b/atomistics/calculators/lammps/output.py new file mode 100644 index 00000000..2a24952d --- /dev/null +++ b/atomistics/calculators/lammps/output.py @@ -0,0 +1,59 @@ +import dataclasses + +from pylammpsmpi import LammpsASELibrary + + +@dataclasses.dataclass +class LammpsOutput: + @classmethod + def fields(cls): + return tuple(field.name for field in dataclasses.fields(cls)) + + def __call__(self, engine: LammpsASELibrary, quantity: str): + return getattr(self, quantity)(engine) + + +@dataclasses.dataclass +class LammpsMDQuantityGetter(LammpsOutput): + positions: callable = LammpsASELibrary.interactive_positions_getter + cell: callable = LammpsASELibrary.interactive_cells_getter + forces: callable = LammpsASELibrary.interactive_forces_getter + temperature: callable = LammpsASELibrary.interactive_temperatures_getter + energy_pot: callable = LammpsASELibrary.interactive_energy_pot_getter + energy_tot: callable = LammpsASELibrary.interactive_energy_tot_getter + pressure: callable = LammpsASELibrary.interactive_pressures_getter + velocities: callable = LammpsASELibrary.interactive_velocities_getter + + +@dataclasses.dataclass +class LammpsStaticQuantityGetter(LammpsOutput): + forces: callable = LammpsASELibrary.interactive_forces_getter + energy_pot: callable = LammpsASELibrary.interactive_energy_pot_getter + stress: callable = LammpsASELibrary.interactive_pressures_getter + + + +quantity_getter_md = LammpsMDQuantityGetter() +quantities_md = quantity_getter_md.fields() +quantity_getter_static = LammpsStaticQuantityGetter() +quantities_static = quantity_getter_static.fields() + + +def get_quantity(lmp_instance, quantity_getter, quantities): + return {q: quantity_getter(lmp_instance, q) for q in quantities} + + +def get_static_output(lmp_instance, quantities=quantities_static): + return get_quantity( + lmp_instance=lmp_instance, + quantity_getter=LammpsStaticQuantityGetter, + quantities=quantities + ) + + +def get_md_output(lmp_instance, quantities=quantities_md): + return get_quantity( + lmp_instance=lmp_instance, + quantity_getter=LammpsMDQuantityGetter, + quantities=quantities + ) From 366bf7137876d3be93c03a7826e8bd9ae0150347 Mon Sep 17 00:00:00 2001 From: pyiron-runner Date: Mon, 11 Dec 2023 20:28:11 +0000 Subject: [PATCH 2/7] Format black --- atomistics/calculators/lammps/calculator.py | 6 +++++- atomistics/calculators/lammps/output.py | 5 ++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/atomistics/calculators/lammps/calculator.py b/atomistics/calculators/lammps/calculator.py index 7fac412e..a0d3b928 100644 --- a/atomistics/calculators/lammps/calculator.py +++ b/atomistics/calculators/lammps/calculator.py @@ -25,7 +25,11 @@ LAMMPS_RUN, LAMMPS_MINIMIZE_VOLUME, ) -from atomistics.calculators.lammps.output import get_static_output, quantities_md, quantities_static +from atomistics.calculators.lammps.output import ( + get_static_output, + quantities_md, + quantities_static, +) if TYPE_CHECKING: from ase import Atoms diff --git a/atomistics/calculators/lammps/output.py b/atomistics/calculators/lammps/output.py index 2a24952d..be9a6c43 100644 --- a/atomistics/calculators/lammps/output.py +++ b/atomistics/calculators/lammps/output.py @@ -32,7 +32,6 @@ class LammpsStaticQuantityGetter(LammpsOutput): stress: callable = LammpsASELibrary.interactive_pressures_getter - quantity_getter_md = LammpsMDQuantityGetter() quantities_md = quantity_getter_md.fields() quantity_getter_static = LammpsStaticQuantityGetter() @@ -47,7 +46,7 @@ def get_static_output(lmp_instance, quantities=quantities_static): return get_quantity( lmp_instance=lmp_instance, quantity_getter=LammpsStaticQuantityGetter, - quantities=quantities + quantities=quantities, ) @@ -55,5 +54,5 @@ def get_md_output(lmp_instance, quantities=quantities_md): return get_quantity( lmp_instance=lmp_instance, quantity_getter=LammpsMDQuantityGetter, - quantities=quantities + quantities=quantities, ) From 87ad74b484dbd7a9b249edae4ff9fab6623ed7eb Mon Sep 17 00:00:00 2001 From: Jan Janssen Date: Mon, 11 Dec 2023 21:54:11 +0100 Subject: [PATCH 3/7] fix getter --- atomistics/calculators/lammps/output.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atomistics/calculators/lammps/output.py b/atomistics/calculators/lammps/output.py index be9a6c43..b89a9a2f 100644 --- a/atomistics/calculators/lammps/output.py +++ b/atomistics/calculators/lammps/output.py @@ -45,7 +45,7 @@ def get_quantity(lmp_instance, quantity_getter, quantities): def get_static_output(lmp_instance, quantities=quantities_static): return get_quantity( lmp_instance=lmp_instance, - quantity_getter=LammpsStaticQuantityGetter, + quantity_getter=quantity_getter_static, quantities=quantities, ) @@ -53,6 +53,6 @@ def get_static_output(lmp_instance, quantities=quantities_static): def get_md_output(lmp_instance, quantities=quantities_md): return get_quantity( lmp_instance=lmp_instance, - quantity_getter=LammpsMDQuantityGetter, + quantity_getter=quantity_getter_md, quantities=quantities, ) From 51551066e76f825d533004583cb773527a1f6d08 Mon Sep 17 00:00:00 2001 From: Jan Janssen Date: Mon, 11 Dec 2023 21:59:36 +0100 Subject: [PATCH 4/7] rename output --- atomistics/calculators/lammps/output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atomistics/calculators/lammps/output.py b/atomistics/calculators/lammps/output.py index b89a9a2f..52328afb 100644 --- a/atomistics/calculators/lammps/output.py +++ b/atomistics/calculators/lammps/output.py @@ -28,7 +28,7 @@ class LammpsMDQuantityGetter(LammpsOutput): @dataclasses.dataclass class LammpsStaticQuantityGetter(LammpsOutput): forces: callable = LammpsASELibrary.interactive_forces_getter - energy_pot: callable = LammpsASELibrary.interactive_energy_pot_getter + energy: callable = LammpsASELibrary.interactive_energy_pot_getter stress: callable = LammpsASELibrary.interactive_pressures_getter From 2d9b389eb0837bd7a67298ab9091b4b07dc617b3 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Mon, 11 Dec 2023 13:53:29 -0800 Subject: [PATCH 5/7] Add a better interface than call Now we can always just use the class directly --- atomistics/calculators/lammps/output.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/atomistics/calculators/lammps/output.py b/atomistics/calculators/lammps/output.py index 52328afb..ba0e48ee 100644 --- a/atomistics/calculators/lammps/output.py +++ b/atomistics/calculators/lammps/output.py @@ -9,6 +9,10 @@ class LammpsOutput: def fields(cls): return tuple(field.name for field in dataclasses.fields(cls)) + @classmethod + def get(cls, engine: LammpsASELibrary, *quantities: str) -> dict: + return {q: getattr(cls, q)(engine) for q in quantities} + def __call__(self, engine: LammpsASELibrary, quantity: str): return getattr(self, quantity)(engine) From 24f1ae2b1c4aa7f58d568a962865a4f49d829762 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Mon, 11 Dec 2023 14:02:36 -0800 Subject: [PATCH 6/7] Purge instances of the dataclasses This also removes the misdirection layer --- atomistics/calculators/lammps/calculator.py | 19 +++++--------- atomistics/calculators/lammps/helpers.py | 11 +++----- atomistics/calculators/lammps/output.py | 29 --------------------- 3 files changed, 10 insertions(+), 49 deletions(-) diff --git a/atomistics/calculators/lammps/calculator.py b/atomistics/calculators/lammps/calculator.py index a0d3b928..5a8c0528 100644 --- a/atomistics/calculators/lammps/calculator.py +++ b/atomistics/calculators/lammps/calculator.py @@ -25,11 +25,7 @@ LAMMPS_RUN, LAMMPS_MINIMIZE_VOLUME, ) -from atomistics.calculators.lammps.output import ( - get_static_output, - quantities_md, - quantities_static, -) +from atomistics.calculators.lammps.output import LammpsMDQuantityGetter, LammpsStaticQuantityGetter if TYPE_CHECKING: from ase import Atoms @@ -117,7 +113,7 @@ def calc_static_with_lammps( structure, potential_dataframe, lmp=None, - quantities=quantities_static, + quantities=LammpsStaticQuantityGetter.fields(), **kwargs, ): template_str = LAMMPS_THERMO_STYLE + "\n" + LAMMPS_THERMO + "\n" + LAMMPS_RUN @@ -131,10 +127,7 @@ def calc_static_with_lammps( lmp=lmp, **kwargs, ) - result_dict = get_static_output( - lmp_instance=lmp_instance, - quantities=quantities, - ) + result_dict = LammpsStaticQuantityGetter.get(lmp_instance, *quantities) lammps_shutdown(lmp_instance=lmp_instance, close_instance=lmp is None) return result_dict @@ -151,7 +144,7 @@ def calc_molecular_dynamics_nvt_with_lammps( seed=4928459, dist="gaussian", lmp=None, - quantities=quantities_md, + quantities=LammpsMDQuantityGetter.fields(), **kwargs, ): init_str = ( @@ -208,7 +201,7 @@ def calc_molecular_dynamics_npt_with_lammps( seed=4928459, dist="gaussian", lmp=None, - quantities=quantities_md, + quantities=LammpsMDQuantityGetter.fields(), **kwargs, ): init_str = ( @@ -266,7 +259,7 @@ def calc_molecular_dynamics_nph_with_lammps( seed=4928459, dist="gaussian", lmp=None, - quantities=quantities_md, + quantities=LammpsMDQuantityGetter.fields(), **kwargs, ): init_str = ( diff --git a/atomistics/calculators/lammps/helpers.py b/atomistics/calculators/lammps/helpers.py index 790bd71e..3435d3e7 100644 --- a/atomistics/calculators/lammps/helpers.py +++ b/atomistics/calculators/lammps/helpers.py @@ -5,7 +5,7 @@ from pylammpsmpi import LammpsASELibrary from atomistics.calculators.lammps.potential import validate_potential_dataframe -from atomistics.calculators.lammps.output import get_md_output, quantities_md +from atomistics.calculators.lammps.output import LammpsMDQuantityGetter def lammps_run(structure, potential_dataframe, input_template=None, lmp=None, **kwargs): @@ -41,14 +41,11 @@ def lammps_calc_md_step( lmp_instance, run_str, run, - quantities=quantities_md, + quantities=LammpsMDQuantityGetter.fields(), ): run_str_rendered = Template(run_str).render(run=run) lmp_instance.interactive_lib_command(run_str_rendered) - return get_md_output( - lmp_instance=lmp_instance, - quantities=quantities, - ) + return LammpsMDQuantityGetter.get(lmp_instance, *quantities) def lammps_calc_md( @@ -56,7 +53,7 @@ def lammps_calc_md( run_str, run, thermo, - quantities=quantities_md, + quantities=LammpsMDQuantityGetter.fields(), ): results_lst = [ lammps_calc_md_step( diff --git a/atomistics/calculators/lammps/output.py b/atomistics/calculators/lammps/output.py index ba0e48ee..8b559fa5 100644 --- a/atomistics/calculators/lammps/output.py +++ b/atomistics/calculators/lammps/output.py @@ -13,9 +13,6 @@ def fields(cls): def get(cls, engine: LammpsASELibrary, *quantities: str) -> dict: return {q: getattr(cls, q)(engine) for q in quantities} - def __call__(self, engine: LammpsASELibrary, quantity: str): - return getattr(self, quantity)(engine) - @dataclasses.dataclass class LammpsMDQuantityGetter(LammpsOutput): @@ -34,29 +31,3 @@ class LammpsStaticQuantityGetter(LammpsOutput): forces: callable = LammpsASELibrary.interactive_forces_getter energy: callable = LammpsASELibrary.interactive_energy_pot_getter stress: callable = LammpsASELibrary.interactive_pressures_getter - - -quantity_getter_md = LammpsMDQuantityGetter() -quantities_md = quantity_getter_md.fields() -quantity_getter_static = LammpsStaticQuantityGetter() -quantities_static = quantity_getter_static.fields() - - -def get_quantity(lmp_instance, quantity_getter, quantities): - return {q: quantity_getter(lmp_instance, q) for q in quantities} - - -def get_static_output(lmp_instance, quantities=quantities_static): - return get_quantity( - lmp_instance=lmp_instance, - quantity_getter=quantity_getter_static, - quantities=quantities, - ) - - -def get_md_output(lmp_instance, quantities=quantities_md): - return get_quantity( - lmp_instance=lmp_instance, - quantity_getter=quantity_getter_md, - quantities=quantities, - ) From 9c328b04bf621426c29d7e5269f8996ad58b95b0 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Mon, 11 Dec 2023 14:04:17 -0800 Subject: [PATCH 7/7] Give the classes shorter names --- atomistics/calculators/lammps/calculator.py | 12 ++++++------ atomistics/calculators/lammps/helpers.py | 8 ++++---- atomistics/calculators/lammps/output.py | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/atomistics/calculators/lammps/calculator.py b/atomistics/calculators/lammps/calculator.py index 5a8c0528..6fee3952 100644 --- a/atomistics/calculators/lammps/calculator.py +++ b/atomistics/calculators/lammps/calculator.py @@ -25,7 +25,7 @@ LAMMPS_RUN, LAMMPS_MINIMIZE_VOLUME, ) -from atomistics.calculators.lammps.output import LammpsMDQuantityGetter, LammpsStaticQuantityGetter +from atomistics.calculators.lammps.output import LammpsMDOutput, LammpsStaticOutput if TYPE_CHECKING: from ase import Atoms @@ -113,7 +113,7 @@ def calc_static_with_lammps( structure, potential_dataframe, lmp=None, - quantities=LammpsStaticQuantityGetter.fields(), + quantities=LammpsStaticOutput.fields(), **kwargs, ): template_str = LAMMPS_THERMO_STYLE + "\n" + LAMMPS_THERMO + "\n" + LAMMPS_RUN @@ -127,7 +127,7 @@ def calc_static_with_lammps( lmp=lmp, **kwargs, ) - result_dict = LammpsStaticQuantityGetter.get(lmp_instance, *quantities) + result_dict = LammpsStaticOutput.get(lmp_instance, *quantities) lammps_shutdown(lmp_instance=lmp_instance, close_instance=lmp is None) return result_dict @@ -144,7 +144,7 @@ def calc_molecular_dynamics_nvt_with_lammps( seed=4928459, dist="gaussian", lmp=None, - quantities=LammpsMDQuantityGetter.fields(), + quantities=LammpsMDOutput.fields(), **kwargs, ): init_str = ( @@ -201,7 +201,7 @@ def calc_molecular_dynamics_npt_with_lammps( seed=4928459, dist="gaussian", lmp=None, - quantities=LammpsMDQuantityGetter.fields(), + quantities=LammpsMDOutput.fields(), **kwargs, ): init_str = ( @@ -259,7 +259,7 @@ def calc_molecular_dynamics_nph_with_lammps( seed=4928459, dist="gaussian", lmp=None, - quantities=LammpsMDQuantityGetter.fields(), + quantities=LammpsMDOutput.fields(), **kwargs, ): init_str = ( diff --git a/atomistics/calculators/lammps/helpers.py b/atomistics/calculators/lammps/helpers.py index 3435d3e7..0e763f71 100644 --- a/atomistics/calculators/lammps/helpers.py +++ b/atomistics/calculators/lammps/helpers.py @@ -5,7 +5,7 @@ from pylammpsmpi import LammpsASELibrary from atomistics.calculators.lammps.potential import validate_potential_dataframe -from atomistics.calculators.lammps.output import LammpsMDQuantityGetter +from atomistics.calculators.lammps.output import LammpsMDOutput def lammps_run(structure, potential_dataframe, input_template=None, lmp=None, **kwargs): @@ -41,11 +41,11 @@ def lammps_calc_md_step( lmp_instance, run_str, run, - quantities=LammpsMDQuantityGetter.fields(), + quantities=LammpsMDOutput.fields(), ): run_str_rendered = Template(run_str).render(run=run) lmp_instance.interactive_lib_command(run_str_rendered) - return LammpsMDQuantityGetter.get(lmp_instance, *quantities) + return LammpsMDOutput.get(lmp_instance, *quantities) def lammps_calc_md( @@ -53,7 +53,7 @@ def lammps_calc_md( run_str, run, thermo, - quantities=LammpsMDQuantityGetter.fields(), + quantities=LammpsMDOutput.fields(), ): results_lst = [ lammps_calc_md_step( diff --git a/atomistics/calculators/lammps/output.py b/atomistics/calculators/lammps/output.py index 8b559fa5..3313bb1f 100644 --- a/atomistics/calculators/lammps/output.py +++ b/atomistics/calculators/lammps/output.py @@ -15,7 +15,7 @@ def get(cls, engine: LammpsASELibrary, *quantities: str) -> dict: @dataclasses.dataclass -class LammpsMDQuantityGetter(LammpsOutput): +class LammpsMDOutput(LammpsOutput): positions: callable = LammpsASELibrary.interactive_positions_getter cell: callable = LammpsASELibrary.interactive_cells_getter forces: callable = LammpsASELibrary.interactive_forces_getter @@ -27,7 +27,7 @@ class LammpsMDQuantityGetter(LammpsOutput): @dataclasses.dataclass -class LammpsStaticQuantityGetter(LammpsOutput): +class LammpsStaticOutput(LammpsOutput): forces: callable = LammpsASELibrary.interactive_forces_getter energy: callable = LammpsASELibrary.interactive_energy_pot_getter stress: callable = LammpsASELibrary.interactive_pressures_getter