Skip to content

Commit

Permalink
Add "awatiable_builder" decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Aug 18, 2024
1 parent 21eb8d0 commit 2774f56
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 6 deletions.
17 changes: 17 additions & 0 deletions aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,23 @@ def decorator(func):

return decorator

@staticmethod
@nonfunctional_usage
def awaitable_builder(**kwargs: Any) -> Callable:
def decorator(func):
# Then, apply task decorator
task_decorated = build_task_from_callable(
func,
inputs=kwargs.get("inputs", []),
outputs=kwargs.get("outputs", []),
)
task_decorated.node_type = "awaitable_builder"
func.identifier = "awaitable_builder"
func.task = func.node = task_decorated
return func

return decorator

# Making decorator_task accessible as 'task'
task = decorator_task

Expand Down
76 changes: 76 additions & 0 deletions aiida_workgraph/engine/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
from __future__ import annotations
from aiida_workgraph.orm.serializer import serialize_to_aiida_nodes
from aiida import orm
from aiida.common.extendeddicts import AttributeDict
from aiida.engine.utils import instantiate_process, prepare_inputs
from aiida.manage import manager
from aiida.engine import run_get_node
from aiida.common import InvalidOperation
from aiida.common.log import AIIDA_LOGGER
from aiida.engine.processes import Process, ProcessBuilder
from aiida.orm import ProcessNode
import typing as t
import time

TYPE_SUBMIT_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder]
LOGGER = AIIDA_LOGGER.getChild("engine.launch")


def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple:
Expand Down Expand Up @@ -142,3 +155,66 @@ def prepare_for_shell_task(task: dict, kwargs: dict) -> dict:
"metadata": metadata or {},
}
return inputs


# modified from aiida.engine.submit
# do not check the scope of the process
def submit(
process: TYPE_SUBMIT_PROCESS,
inputs: dict[str, t.Any] | None = None,
*,
wait: bool = False,
wait_interval: int = 5,
**kwargs: t.Any,
) -> ProcessNode:
"""Submit the process with the supplied inputs to the daemon immediately returning control to the interpreter.
.. warning: this should not be used within another process. Instead, there one should use the ``submit`` method of
the wrapping process itself, i.e. use ``self.submit``.
.. warning: submission of processes requires ``store_provenance=True``.
:param process: the process class, instance or builder to submit
:param inputs: the input dictionary to be passed to the process
:param wait: when set to ``True``, the submission will be blocking and wait for the process to complete at which
point the function returns the calculation node.
:param wait_interval: the number of seconds to wait between checking the state of the process when ``wait=True``.
:param kwargs: inputs to be passed to the process. This is an alternative to the positional ``inputs`` argument.
:return: the calculation node of the process
"""
inputs = prepare_inputs(inputs, **kwargs)

runner = manager.get_manager().get_runner()
assert runner.persister is not None, "runner does not have a persister"
assert runner.controller is not None, "runner does not have a controller"

process_inited = instantiate_process(runner, process, **inputs)

# If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose for this
# instead of raising, because in this way the user does not have to change the launcher when testing. The same goes
# for if `remote_folder` is present in the inputs, which means we are importing an already completed calculation.
if process_inited.metadata.get("dry_run", False) or "remote_folder" in inputs:
_, node = run_get_node(process_inited)
return node

if not process_inited.metadata.store_provenance:
raise InvalidOperation("cannot submit a process with `store_provenance=False`")

runner.persister.save_checkpoint(process_inited)
process_inited.close()

# Do not wait for the future's result, because in the case of a single worker this would cock-block itself
runner.controller.continue_process(process_inited.pid, nowait=False, no_reply=True)
node = process_inited.node

if not wait:
return node

while not node.is_terminated:
LOGGER.report(
f"Process<{node.pk}> has not yet terminated, current state is `{node.process_state}`. "
f"Waiting for {wait_interval} seconds."
)
time.sleep(wait_interval)

return node
40 changes: 34 additions & 6 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
__all__ = "WorkGraph"


MAX_NUMBER_AWAITABLES_MSG = "The maximum number of subprocesses has been reached: {}. Cannot launch the job: {}."
MAX_NUMBER_AWAITABLES_MSG = "The maximum number of subprocesses has been reached: {}.\
Waiting for other jobs to finish before launching the {}."


@auto_persist("_awaitables")
Expand Down Expand Up @@ -289,16 +290,16 @@ def _do_step(self) -> t.Any:
else:
finished, result = self.is_workgraph_finished()

if self._awaitables:
return Wait(self._do_step, "Waiting before next step")

# If the workgraph is finished or the result is an ExitCode, we exit by returning
if finished:
if isinstance(result, ExitCode):
return result
else:
return self.finalize()

if self._awaitables:
return Wait(self._do_step, "Waiting before next step")

return Continue(self._do_step)

def _store_nodes(self, data: t.Any) -> None:
Expand Down Expand Up @@ -390,7 +391,10 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None:

# node finished, update the task state and result
# udpate the task state
self.update_task_state(awaitable.key)
if awaitable.key in self.ctx._tasks:
self.update_task_state(awaitable.key)
else:
self.report(f"Awaitable {awaitable.key} finished.")
# try to resume the workgraph, if the workgraph is already resumed
# by other awaitable, this will not work
try:
Expand Down Expand Up @@ -870,9 +874,10 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
"WORKGRAPH",
"PYTHONJOB",
"SHELLJOB",
"AWAITABLE_BUILDER",
]:
if len(self._awaitables) >= self.ctx._max_number_awaitables:
print(
self.report(
MAX_NUMBER_AWAITABLES_MSG.format(
self.ctx._max_number_awaitables, name
)
Expand Down Expand Up @@ -1066,6 +1071,29 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
self.set_task_state_info(name, "state", "FINISHED")
self.update_parent_task_state(name)
self.continue_workgraph()
elif task["metadata"]["node_type"].upper() in ["AWAITABLE_BUILDER"]:
# create the awaitable
for key in self.ctx._tasks[name]["metadata"]["args"]:
kwargs.pop(key, None)
results = self.run_executor(
executor, args, kwargs, var_args, var_kwargs
)
if not isinstance(results, dict):
self.report("The results of the awaitable builder must be a dict.")
for key, value in results.items():
if not isinstance(value, ProcessNode):
self.report(
f"The value of key {key} is not an instance of ProcessNode."
)
self.set_task_state_info(name, "state", "Failed")
self.set_task_state_info(name, "state", "Failed")
self.report(f"Task: {name} failed.")
else:
self.set_task_state_info(name, "state", "FINISHED")
self.to_context(**results)
self.report(f"Task: {name} finished.")
self.update_parent_task_state(name)
self.continue_workgraph()
elif task["metadata"]["node_type"].upper() in ["NORMAL"]:
# normal function does not have a process
if "context" in task["metadata"]["kwargs"]:
Expand Down

0 comments on commit 2774f56

Please sign in to comment.