Skip to content

Commit

Permalink
Defer import of dask.array
Browse files Browse the repository at this point in the history
  • Loading branch information
ericpre committed Jun 3, 2023
1 parent 3fb63af commit 6d4a2a2
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 17 deletions.
8 changes: 0 additions & 8 deletions pint/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,6 @@ def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False):
# Define location of pint.Quantity in NEP-13 type cast hierarchy by defining upcast
# types using guarded imports

try:
from dask import array as dask_array
from dask.base import compute, persist, visualize
except ImportError:
compute, persist, visualize = None, None, None
dask_array = None


# TODO: merge with upcast_type_map

#: List upcast type names
Expand Down
32 changes: 25 additions & 7 deletions pint/facets/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Generic, Any
import functools

from ...compat import compute, dask_array, persist, visualize, TypeAlias
from ...compat import TypeAlias
from ..plain import (
GenericPlainRegistry,
PlainQuantity,
Expand All @@ -25,14 +25,20 @@
)


def is_dask_array(obj):
return type(obj).__name__ == "Array" and "dask" == type(obj).__module__[:4]


def check_dask_array(f):
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
if isinstance(self._magnitude, dask_array.Array):
if is_dask_array(self._magnitude):
return f(self, *args, **kwargs)
else:
msg = "Method {} only implemented for objects of {}, not {}".format(
f.__name__, dask_array.Array, self._magnitude.__class__
msg = (
"Method {} only implemented for objects of dask array, not {}.".format(
f.__name__, self._magnitude.__class__.__name__
)
)
raise AttributeError(msg)

Expand All @@ -42,7 +48,9 @@ def wrapper(self, *args, **kwargs):
class DaskQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
# Dask.array.Array ducking
def __dask_graph__(self):
if isinstance(self._magnitude, dask_array.Array):
import dask.array as da

if isinstance(self._magnitude, da.Array):
return self._magnitude.__dask_graph__()

return None
Expand All @@ -57,11 +65,15 @@ def __dask_tokenize__(self):

@property
def __dask_optimize__(self):
return dask_array.Array.__dask_optimize__
import dask.array as da

return da.Array.__dask_optimize__

@property
def __dask_scheduler__(self):
return dask_array.Array.__dask_scheduler__
import dask.array as da

return da.Array.__dask_scheduler__

def __dask_postcompute__(self):
func, args = self._magnitude.__dask_postcompute__()
Expand Down Expand Up @@ -89,6 +101,8 @@ def compute(self, **kwargs):
pint.PlainQuantity
A pint.PlainQuantity wrapped numpy array.
"""
from dask.base import compute

(result,) = compute(self, **kwargs)
return result

Expand All @@ -106,6 +120,8 @@ def persist(self, **kwargs):
pint.PlainQuantity
A pint.PlainQuantity wrapped Dask array.
"""
from dask.base import persist

(result,) = persist(self, **kwargs)
return result

Expand All @@ -124,6 +140,8 @@ def visualize(self, **kwargs):
-------
"""
from dask.base import visualize

visualize(self, **kwargs)


Expand Down
3 changes: 1 addition & 2 deletions pint/testsuite/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,7 @@ def test_exception_method_not_implemented(local_registry, numpy_array, method):

exctruth = (
f"Method {method} only implemented for objects of"
" <class 'dask.array.core.Array'>, not"
" <class 'numpy.ndarray'>"
" dask array, not ndarray."
)
with pytest.raises(AttributeError, match=exctruth):
obj_method = getattr(q, method)
Expand Down

0 comments on commit 6d4a2a2

Please sign in to comment.