diff --git a/modin/core/execution/dask/common/engine_wrapper.py b/modin/core/execution/dask/common/engine_wrapper.py index c79f83e7d68..2ae7afb783b 100644 --- a/modin/core/execution/dask/common/engine_wrapper.py +++ b/modin/core/execution/dask/common/engine_wrapper.py @@ -19,6 +19,25 @@ from dask.distributed import wait from distributed import Future from distributed.client import default_client +from distributed.worker import get_worker + + +def get_dask_client(): + """ + Get the Dask client, reusing the worker's client if execution is on a Dask worker. + + Returns + ------- + distributed.Client + The Dask client. + """ + try: + client = default_client() + except ValueError: + # We ought to be in a worker process + worker = get_worker() + client = worker.client + return client def _deploy_dask_func(func, *args, return_pandas_df=None, **kwargs): # pragma: no cover @@ -83,7 +102,7 @@ def deploy( list The result of ``func`` split into parts in accordance with ``num_returns``. """ - client = default_client() + client = get_dask_client() args = [] if f_args is None else f_args kwargs = {} if f_kwargs is None else f_kwargs if callable(func): @@ -137,7 +156,7 @@ def materialize(cls, future): Any An object(s) from the distributed memory. """ - client = default_client() + client = get_dask_client() return client.gather(future) @classmethod @@ -164,7 +183,7 @@ def put(cls, data, **kwargs): # {'sep': , \ # 'delimiter': ... data = UserDict(data) - client = default_client() + client = get_dask_client() return client.scatter(data, **kwargs) @classmethod diff --git a/modin/core/execution/dask/common/utils.py b/modin/core/execution/dask/common/utils.py index 067a94fcdf0..1d8d9425733 100644 --- a/modin/core/execution/dask/common/utils.py +++ b/modin/core/execution/dask/common/utils.py @@ -30,6 +30,17 @@ def initialize_dask(): """Initialize Dask environment.""" from distributed.client import default_client + from distributed.worker import get_worker + + try: + # Check if running within a Dask worker process + get_worker() + # If the above line does not raise an error, we are in a worker process + # and should not create a new client + return + except ValueError: + # Not in a Dask worker, proceed to check for or create a client + pass try: client = default_client()