Skip to content

Commit

Permalink
FIX-modin-project#7346: Handle execution on Dask workers to avoid cre…
Browse files Browse the repository at this point in the history
…ating conflicting Clients

Signed-off-by: Michael Akerman <[email protected]>
  • Loading branch information
data-makerman committed Jul 19, 2024
1 parent 4815bc3 commit b5f50b9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
25 changes: 22 additions & 3 deletions modin/core/execution/dask/common/engine_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,25 @@
from dask.distributed import wait
from distributed import Future
from distributed.client import default_client
from distributed.worker import get_worker


def get_dask_client():
"""
Get the Dask client, reusing the worker's client if execution is on a Dask worker.
Returns
-------
distributed.Client
The Dask client.
"""
try:
client = default_client()
except ValueError:
# We ought to be in a worker process
worker = get_worker()
client = worker.client
return client


def _deploy_dask_func(func, *args, return_pandas_df=None, **kwargs): # pragma: no cover
Expand Down Expand Up @@ -83,7 +102,7 @@ def deploy(
list
The result of ``func`` split into parts in accordance with ``num_returns``.
"""
client = default_client()
client = get_dask_client()
args = [] if f_args is None else f_args
kwargs = {} if f_kwargs is None else f_kwargs
if callable(func):
Expand Down Expand Up @@ -137,7 +156,7 @@ def materialize(cls, future):
Any
An object(s) from the distributed memory.
"""
client = default_client()
client = get_dask_client()
return client.gather(future)

@classmethod
Expand All @@ -164,7 +183,7 @@ def put(cls, data, **kwargs):
# {'sep': <Future: finished, type: pandas._libs.lib._NoDefault, key: sep>, \
# 'delimiter': <Future: finished, type: NoneType, key: delimiter> ...
data = UserDict(data)
client = default_client()
client = get_dask_client()
return client.scatter(data, **kwargs)

@classmethod
Expand Down
11 changes: 11 additions & 0 deletions modin/core/execution/dask/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@
def initialize_dask():
"""Initialize Dask environment."""
from distributed.client import default_client
from distributed.worker import get_worker

try:
# Check if running within a Dask worker process
get_worker()
# If the above line does not raise an error, we are in a worker process
# and should not create a new client
return
except ValueError:
# Not in a Dask worker, proceed to check for or create a client
pass

try:
client = default_client()
Expand Down

0 comments on commit b5f50b9

Please sign in to comment.