diff --git a/dask_mpi/core.py b/dask_mpi/core.py index e133a42..47994a4 100644 --- a/dask_mpi/core.py +++ b/dask_mpi/core.py @@ -10,6 +10,7 @@ def initialize( + comm, interface=None, nthreads=1, local_directory=None, @@ -20,7 +21,6 @@ def initialize( protocol=None, worker_class="distributed.Worker", worker_options=None, - comm=None, exit=True, ): """ @@ -38,6 +38,8 @@ def initialize( Parameters ---------- + comm: mpi4py.MPI.Intracomm + Optional MPI communicator to use instead of COMM_WORLD interface : str Network interface like 'eth0' or 'ib0' nthreads : int @@ -59,8 +61,6 @@ def initialize( Class to use when creating workers worker_options : dict Options to pass to workers - comm: mpi4py.MPI.Intracomm - Optional MPI communicator to use instead of COMM_WORLD exit: bool Whether to call sys.exit on the workers and schedulers when the event loop completes. @@ -71,10 +71,9 @@ def initialize( Only returned if exit=False. Inidcates whether this rank should continue to run client code (True), or if it acts as a scheduler or worker (False). """ - if comm is None: - from mpi4py import MPI - - comm = MPI.COMM_WORLD + assert ( + comm is not None + ), "MPI Comm World needs to be created before import distributed." world_size = comm.Get_size() if world_size < 3: