Skip to content

Commit

Permalink
Merge pull request #139 from pyiron/connection
Browse files Browse the repository at this point in the history
Abstraction for connection interface
  • Loading branch information
jan-janssen authored Jul 30, 2023
2 parents 8cf7c24 + ea81df7 commit 29734b5
Show file tree
Hide file tree
Showing 13 changed files with 617 additions and 380 deletions.
9 changes: 5 additions & 4 deletions pympipool/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from pympipool.shared.communication import (
SocketInterface,
connect_to_socket_interface,
send_result,
close_connection,
receive_instruction,
interface_connect,
interface_bootup,
interface_send,
interface_shutdown,
interface_receive,
)
from pympipool.interfaces.taskbroker import HPCExecutor
from pympipool.interfaces.taskexecutor import Executor
Expand Down
20 changes: 10 additions & 10 deletions pympipool/backend/mpiexec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import cloudpickle

from pympipool.shared.communication import (
connect_to_socket_interface,
send_result,
close_connection,
receive_instruction,
interface_connect,
interface_send,
interface_shutdown,
interface_receive,
)
from pympipool.shared.backend import call_funct, parse_arguments

Expand All @@ -26,7 +26,7 @@ def main():

argument_dict = parse_arguments(argument_lst=sys.argv)
if mpi_rank_zero:
context, socket = connect_to_socket_interface(
context, socket = interface_connect(
host=argument_dict["host"], port=argument_dict["zmqport"]
)
else:
Expand All @@ -43,16 +43,16 @@ def main():
while True:
# Read from socket
if mpi_rank_zero:
input_dict = receive_instruction(socket=socket)
input_dict = interface_receive(socket=socket)
else:
input_dict = None
input_dict = MPI.COMM_WORLD.bcast(input_dict, root=0)

# Parse input
if "shutdown" in input_dict.keys() and input_dict["shutdown"]:
if mpi_rank_zero:
send_result(socket=socket, result_dict={"result": True})
close_connection(socket=socket, context=context)
interface_send(socket=socket, result_dict={"result": True})
interface_shutdown(socket=socket, context=context)
break
elif (
"fn" in input_dict.keys()
Expand All @@ -69,14 +69,14 @@ def main():
output_reply = output
except Exception as error:
if mpi_rank_zero:
send_result(
interface_send(
socket=socket,
result_dict={"error": error, "error_type": str(type(error))},
)
else:
# Send output
if mpi_rank_zero:
send_result(socket=socket, result_dict={"result": output_reply})
interface_send(socket=socket, result_dict={"result": output_reply})
elif (
"init" in input_dict.keys()
and input_dict["init"]
Expand Down
20 changes: 10 additions & 10 deletions pympipool/legacy/backend/mpipool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import cloudpickle

from pympipool.shared.communication import (
connect_to_socket_interface,
send_result,
close_connection,
receive_instruction,
interface_connect,
interface_send,
interface_shutdown,
interface_receive,
)
from pympipool.legacy.shared.backend import parse_socket_communication, parse_arguments

Expand Down Expand Up @@ -36,27 +36,27 @@ def main():
path=sys.path, # required for flux interface - otherwise the current path is not included in the python path
) as executor:
if executor is not None:
context, socket = connect_to_socket_interface(
context, socket = interface_connect(
host=argument_dict["host"], port=argument_dict["zmqport"]
)
while True:
output = parse_socket_communication(
executor=executor,
input_dict=receive_instruction(socket=socket),
input_dict=interface_receive(socket=socket),
future_dict=future_dict,
cores_per_task=int(argument_dict["cores_per_task"]),
)
if "exit" in output.keys() and output["exit"]:
if "result" in output.keys():
send_result(
interface_send(
socket=socket, result_dict={"result": output["result"]}
)
else:
send_result(socket=socket, result_dict={"result": True})
close_connection(socket=socket, context=context)
interface_send(socket=socket, result_dict={"result": True})
interface_shutdown(socket=socket, context=context)
break
elif isinstance(output, dict):
send_result(socket=socket, result_dict=output)
interface_send(socket=socket, result_dict=output)


if __name__ == "__main__":
Expand Down
61 changes: 25 additions & 36 deletions pympipool/legacy/interfaces/pool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import ABC

from pympipool.shared.communication import SocketInterface
from pympipool.shared.communication import interface_bootup
from pympipool.shared.taskexecutor import cloudpickle_register
from pympipool.legacy.shared.interface import get_parallel_subprocess_command
from pympipool.legacy.shared.interface import get_pool_command


class PoolBase(ABC):
Expand All @@ -11,11 +11,9 @@ class PoolBase(ABC):
alone. Rather it implements the __enter__(), __exit__() and shutdown() function shared between the derived classes.
"""

def __init__(self, queue_adapter=None, queue_adapter_kwargs=None):
def __init__(self):
self._future_dict = {}
self._interface = SocketInterface(
queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs
)
self._interface = None
cloudpickle_register(ind=3)

def __enter__(self):
Expand Down Expand Up @@ -71,23 +69,17 @@ def __init__(
queue_adapter=None,
queue_adapter_kwargs=None,
):
super().__init__(
queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs
)
self._interface.bootup(
command_lst=get_parallel_subprocess_command(
port_selected=self._interface.bind_to_random_port(),
cores=max_workers,
cores_per_task=1,
gpus_per_task=gpus_per_task,
oversubscribe=oversubscribe,
enable_flux_backend=enable_flux_backend,
enable_slurm_backend=enable_slurm_backend,
enable_mpi4py_backend=True,
enable_multi_host=queue_adapter is not None,
),
super().__init__()
self._interface = interface_bootup(
command_lst=get_pool_command(cores_total=max_workers, ranks_per_task=1)[0],
cwd=cwd,
cores=max_workers,
gpus_per_core=gpus_per_task,
oversubscribe=oversubscribe,
enable_flux_backend=enable_flux_backend,
enable_slurm_backend=enable_slurm_backend,
queue_adapter=queue_adapter,
queue_adapter_kwargs=queue_adapter_kwargs,
)

def map(self, func, iterable, chunksize=None):
Expand Down Expand Up @@ -178,23 +170,20 @@ def __init__(
queue_adapter=None,
queue_adapter_kwargs=None,
):
super().__init__(
queue_adapter=queue_adapter, queue_adapter_kwargs=queue_adapter_kwargs
super().__init__()
command_lst, cores = get_pool_command(
cores_total=max_ranks, ranks_per_task=ranks_per_task
)
self._interface.bootup(
command_lst=get_parallel_subprocess_command(
port_selected=self._interface.bind_to_random_port(),
cores=max_ranks,
cores_per_task=ranks_per_task,
gpus_per_task=gpus_per_task,
oversubscribe=oversubscribe,
enable_flux_backend=False,
enable_slurm_backend=False,
enable_mpi4py_backend=True,
enable_multi_host=queue_adapter is not None,
),
self._interface = interface_bootup(
command_lst=command_lst,
cwd=cwd,
cores=max_ranks,
cores=cores,
gpus_per_core=gpus_per_task,
oversubscribe=oversubscribe,
enable_flux_backend=False,
enable_slurm_backend=False,
queue_adapter=queue_adapter,
queue_adapter_kwargs=queue_adapter_kwargs,
)

def map(self, func, iterable, chunksize=None):
Expand Down
Loading

0 comments on commit 29734b5

Please sign in to comment.