diff --git a/pyiron_contrib/workflow/node_library/atomistics.py b/pyiron_contrib/workflow/node_library/atomistics.py index 12663edf1..781d039b4 100644 --- a/pyiron_contrib/workflow/node_library/atomistics.py +++ b/pyiron_contrib/workflow/node_library/atomistics.py @@ -24,42 +24,27 @@ def lammps(structure: Optional[Atoms] = None) -> LammpsJob: return job -@node( - "cells", - "displacements", - "energy_pot", - "energy_tot", - "force_max", - "forces", - "indices", - "positions", - "pressures", - "steps", - "temperature", - "total_displacements", - "unwrapped_positions", - "volume", -) -def calc_md( - job: AtomisticGenericJob, - n_ionic_steps: int = 1000, - n_print: int = 100, - temperature: int | float = 300.0, - pressure: float - | tuple[float, float, float] - | tuple[float, float, float, float, float, float] - | None = None, -): +def _run_and_remove_job(job, modifier: Optional[callable] = None, **modifier_kwargs): + """ + Extracts the commonalities for all the "calc" methods for running a Lammps engine. + Will need to be extended/updated once we support other engines so that more output + can be parsed. Output may wind up more concretely packaged, e.g. as `CalcOutput` or + `MDOutput`, etc., ala Joerg's suggestion later, so for the time being we don't put + too much effort into this. + + Warning: + Jobs are created in a dummy project with a dummy name and are all removed at the + end; this works fine for serial workflows, but will need to be revisited -- + probably with naming based on the parantage of node/workflow labels -- once + other non-serial execution is introduced. + """ job_name = "JUSTAJOBNAME" pr = Project("WORKFLOWNAMEPROJECT") job = job.copy_to(project=pr, new_job_name=job_name, delete_existing_job=True) - job.calc_md( - n_ionic_steps=n_ionic_steps, - n_print=n_print, - temperature=temperature, - pressure=pressure, - ) + if modifier is not None: + job = modifier(job, **modifier_kwargs) job.run() + cells = job.output.cells displacements = job.output.displacements energy_pot = job.output.energy_pot @@ -74,8 +59,10 @@ def calc_md( total_displacements = job.output.total_displacements unwrapped_positions = job.output.unwrapped_positions volume = job.output.volume + job.remove() pr.remove(enable=True) + return ( cells, displacements, @@ -94,8 +81,76 @@ def calc_md( ) +@node( + "cells", + "displacements", + "energy_pot", + "energy_tot", + "force_max", + "forces", + "indices", + "positions", + "pressures", + "steps", + "temperature", + "total_displacements", + "unwrapped_positions", + "volume", +) +def calc_static( + job: AtomisticGenericJob, +): + return _run_and_remove_job(job=job) + + +@node( + "cells", + "displacements", + "energy_pot", + "energy_tot", + "force_max", + "forces", + "indices", + "positions", + "pressures", + "steps", + "temperature", + "total_displacements", + "unwrapped_positions", + "volume", +) +def calc_md( + job: AtomisticGenericJob, + n_ionic_steps: int = 1000, + n_print: int = 100, + temperature: int | float = 300.0, + pressure: float + | tuple[float, float, float] + | tuple[float, float, float, float, float, float] + | None = None, +): + def calc_md(job, n_ionic_steps, n_print, temperature, pressure): + job.calc_md( + n_ionic_steps=n_ionic_steps, + n_print=n_print, + temperature=temperature, + pressure=pressure, + ) + return job + + return _run_and_remove_job( + job=job, + modifier=calc_md, + n_ionic_steps=n_ionic_steps, + n_print=n_print, + temperature=temperature, + pressure=pressure, + ) + + nodes = [ bulk_structure, calc_md, + calc_static, lammps, ] diff --git a/pyiron_contrib/workflow/workflow.py b/pyiron_contrib/workflow/workflow.py index ba460416f..b9af9bf63 100644 --- a/pyiron_contrib/workflow/workflow.py +++ b/pyiron_contrib/workflow/workflow.py @@ -31,7 +31,7 @@ def __getattribute__(self, key): return value def __call__(self, node: Node): - self._workflow.add_node(node) + return self._workflow.add_node(node) class _NodeDecoratorAccess: @@ -180,6 +180,7 @@ def add_node(self, node: Node, label: str = None) -> None: self.nodes[label] = node node.label = label node.workflow = self + return node def _ensure_node_belongs_to_at_most_this_workflow(self, node: Node, label: str): if (