Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove redundant methods in P2PBarrierTask #8924

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions distributed/shuffle/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import dask.config
from dask._task_spec import Task, _inline_recursively
from dask.core import flatten
from dask.sizeof import sizeof
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta

Expand Down Expand Up @@ -601,41 +600,17 @@ def __init__(
super().__init__(key, func, *args, **kwargs)

def copy(self) -> P2PBarrierTask:
self.unpack()
assert self.func is not None
return P2PBarrierTask(
self.key, self.func, *self.args, spec=self.spec, **self.kwargs
)

def __sizeof__(self) -> int:
return super().__sizeof__() + sizeof(self.spec)

def __repr__(self) -> str:
return f"P2PBarrierTask({self.key!r})"

def inline(self, dsk: dict[Key, Any]) -> P2PBarrierTask:
self.unpack()
new_args = _inline_recursively(self.args, dsk)
new_kwargs = _inline_recursively(self.kwargs, dsk)
assert self.func is not None
return P2PBarrierTask(
self.key, self.func, *new_args, spec=self.spec, **new_kwargs
)

def __getstate__(self) -> dict[str, Any]:
state = super().__getstate__()
state["spec"] = self.spec
return state

def __setstate__(self, state: dict[str, Any]) -> None:
super().__setstate__(state)
self.spec = state["spec"]

def __eq__(self, value: object) -> bool:
if not isinstance(value, P2PBarrierTask):
return False
if not super().__eq__(value):
return False
if self.spec != value.spec:
return False
return True
35 changes: 16 additions & 19 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import dask
import dask.bag as db
from dask import delayed
from dask._task_spec import no_function_cache
from dask.optimization import SubgraphCallable
from dask.tokenize import tokenize
from dask.utils import get_default_shuffle_method, parse_timedelta, tmpfile
Expand Down Expand Up @@ -4934,29 +4933,27 @@ def __setstate__(self, state):

@gen_cluster(client=True)
async def test_robust_undeserializable_function(c, s, a, b, monkeypatch):
with no_function_cache():

class Foo:
def __getstate__(self):
return 1
class Foo:
def __getstate__(self):
return 1

def __setstate__(self, state):
raise MyException("hello")
def __setstate__(self, state):
raise MyException("hello")

def __call__(self, *args):
return 1
def __call__(self, *args):
return 1

future = c.submit(Foo(), 1)
await wait(future)
assert future.status == "error"
with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"):
await future
future = c.submit(Foo(), 1)
await wait(future)
assert future.status == "error"
with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"):
await future

futures = c.map(inc, range(10))
results = await c.gather(futures)
futures = c.map(inc, range(10))
results = await c.gather(futures)

assert results == list(map(inc, range(10)))
assert a.data and b.data
assert results == list(map(inc, range(10)))
assert a.data and b.data


@gen_cluster(client=True)
Expand Down
Loading