diff --git a/doc/dask.rst b/doc/dask.rst
index 824f30aba4f..8fc0f655023 100644
--- a/doc/dask.rst
+++ b/doc/dask.rst
@@ -100,6 +100,14 @@ Once you've manipulated a dask array, you can still write a dataset too big to
 fit into memory back to disk by using :py:meth:`~xarray.Dataset.to_netcdf` in
 the usual way.
 
+.. note::
+
+    When using dask's distributed scheduler to write NETCDF4 files,
+    it may be necessary to set the environment variable `HDF5_USE_FILE_LOCKING=FALSE`
+    to avoid competing locks within the HDF5 SWMR file locking scheme. Note that
+    writing netCDF files with dask's distributed scheduler is only supported for
+    the `netcdf4` backend.
+
 A dataset can also be converted to a dask DataFrame using :py:meth:`~xarray.Dataset.to_dask_dataframe`.
 
 .. ipython:: python
diff --git a/doc/io.rst b/doc/io.rst
index c177496f6f2..c14e1516b38 100644
--- a/doc/io.rst
+++ b/doc/io.rst
@@ -672,9 +672,9 @@ files into a single Dataset by making use of :py:func:`~xarray.concat`.
 
 .. note::
 
-    Version 0.5 includes support for manipulating datasets that
-    don't fit into memory with dask_. If you have dask installed, you can open
-    multiple files simultaneously using :py:func:`~xarray.open_mfdataset`::
+    Xarray includes support for manipulating datasets that don't fit into memory
+    with dask_. If you have dask installed, you can open multiple files
+    simultaneously using :py:func:`~xarray.open_mfdataset`::
 
         xr.open_mfdataset('my/files/*.nc')
 
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 57ae66818d3..eba9a16c711 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -38,6 +38,11 @@ Documentation
 Enhancements
 ~~~~~~~~~~~~
 
+- Support for writing xarray datasets to netCDF files (netcdf4 backend only)
+  when using the `dask.distributed `_
+  scheduler (:issue:`1464`).
+  By `Joe Hamman `_.
+
 - Support lazy vectorized-indexing. After this change, flexible indexing such
   as orthogonal/vectorized indexing, becomes possible for all the backend
   arrays. Also, lazy ``transpose`` is now also supported.
(:issue:`1897`) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 9d0b95c8c81..a22356f66b0 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -12,7 +12,8 @@ from ..core.combine import auto_combine from ..core.pycompat import basestring, path_type from ..core.utils import close_on_error, is_remote_uri -from .common import GLOBAL_LOCK, ArrayWriter +from .common import ( + HDF5_LOCK, ArrayWriter, CombinedLock, get_scheduler, get_scheduler_lock) DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' @@ -64,9 +65,9 @@ def _default_lock(filename, engine): else: # TODO: identify netcdf3 files and don't use the global lock # for them - lock = GLOBAL_LOCK + lock = HDF5_LOCK elif engine in {'h5netcdf', 'pynio'}: - lock = GLOBAL_LOCK + lock = HDF5_LOCK else: lock = False return lock @@ -129,6 +130,20 @@ def _protect_dataset_variables_inplace(dataset, cache): variable.data = data +def _get_lock(engine, scheduler, format, path_or_file): + """ Get the lock(s) that apply to a particular scheduler/engine/format""" + + locks = [] + if format in ['NETCDF4', None] and engine in ['h5netcdf', 'netcdf4']: + locks.append(HDF5_LOCK) + locks.append(get_scheduler_lock(scheduler, path_or_file)) + + # When we have more than one lock, use the CombinedLock wrapper class + lock = CombinedLock(locks) if len(locks) > 1 else locks[0] + + return lock + + def open_dataset(filename_or_obj, group=None, decode_cf=True, mask_and_scale=True, decode_times=True, autoclose=False, concat_characters=True, decode_coords=True, engine=None, @@ -620,8 +635,20 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, # if a writer is provided, store asynchronously sync = writer is None + # handle scheduler specific logic + scheduler = get_scheduler() + if (dataset.chunks and scheduler in ['distributed', 'multiprocessing'] and + engine != 'netcdf4'): + raise NotImplementedError("Writing netCDF files with the %s backend " + "is not currently supported with dask's %s " + "scheduler" % (engine, scheduler)) + lock = _get_lock(engine, scheduler, format, path_or_file) + autoclose = (dataset.chunks and + scheduler in ['distributed', 'multiprocessing']) + target = path_or_file if path_or_file is not None else BytesIO() - store = store_open(target, mode, format, group, writer) + store = store_open(target, mode, format, group, writer, + autoclose=autoclose, lock=lock) if unlimited_dims is None: unlimited_dims = dataset.encoding.get('unlimited_dims', None) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index d91cedbbda3..c46f9d5b552 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -2,6 +2,8 @@ import contextlib import logging +import multiprocessing +import threading import time import traceback import warnings @@ -14,11 +16,12 @@ from ..core.pycompat import dask_array_type, iteritems from ..core.utils import FrozenOrderedDict, NdimSizeLenMixin +# Import default lock try: - from dask.utils import SerializableLock as Lock + from dask.utils import SerializableLock + HDF5_LOCK = SerializableLock() except ImportError: - from threading import Lock - + HDF5_LOCK = threading.Lock() # Create a logger object, but don't add any handlers. Leave that to user code. 
 logger = logging.getLogger(__name__)
@@ -27,8 +30,54 @@
 NONE_VAR_NAME = '__values__'
 
 
-# dask.utils.SerializableLock if available, otherwise just a threading.Lock
-GLOBAL_LOCK = Lock()
+def get_scheduler(get=None, collection=None):
+    """ Determine the dask scheduler that is being used.
+
+    None is returned if no dask scheduler is active.
+
+    See also
+    --------
+    dask.utils.effective_get
+    """
+    try:
+        from dask.utils import effective_get
+        actual_get = effective_get(get, collection)
+        try:
+            from dask.distributed import Client
+            if isinstance(actual_get.__self__, Client):
+                return 'distributed'
+        except (ImportError, AttributeError):
+            try:
+                import dask.multiprocessing
+                if actual_get == dask.multiprocessing.get:
+                    return 'multiprocessing'
+                else:
+                    return 'threaded'
+            except ImportError:
+                return 'threaded'
+    except ImportError:
+        return None
+
+
+def get_scheduler_lock(scheduler, path_or_file=None):
+    """ Get the appropriate lock for a certain situation based on the dask
+    scheduler used.
+
+    See Also
+    --------
+    dask.utils.get_scheduler_lock
+    """
+
+    if scheduler == 'distributed':
+        from dask.distributed import Lock
+        return Lock(path_or_file)
+    elif scheduler == 'multiprocessing':
+        return multiprocessing.Lock()
+    elif scheduler == 'threaded':
+        from dask.utils import SerializableLock
+        return SerializableLock()
+    else:
+        return threading.Lock()
 
 
 def _encode_variable_name(name):
@@ -77,6 +126,39 @@ def robust_getitem(array, key, catch=Exception, max_retries=6,
             time.sleep(1e-3 * next_delay)
 
 
+class CombinedLock(object):
+    """A combination of multiple locks.
+
+    Like a locked door, a CombinedLock is locked if any of its constituent
+    locks are locked.
+    """
+
+    def __init__(self, locks):
+        self.locks = tuple(set(locks))  # remove duplicates
+
+    def acquire(self, *args):
+        return all(lock.acquire(*args) for lock in self.locks)
+
+    def release(self, *args):
+        for lock in self.locks:
+            lock.release(*args)
+
+    def __enter__(self):
+        for lock in self.locks:
+            lock.__enter__()
+
+    def __exit__(self, *args):
+        for lock in self.locks:
+            lock.__exit__(*args)
+
+    @property
+    def locked(self):
+        return any(lock.locked for lock in self.locks)
+
+    def __repr__(self):
+        return "CombinedLock(%r)" % list(self.locks)
+
+
 class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
 
     def __array__(self, dtype=None):
@@ -85,7 +167,9 @@ def __array__(self, dtype=None):
 
 
 class AbstractDataStore(Mapping):
-    _autoclose = False
+    _autoclose = None
+    _ds = None
+    _isopen = False
 
     def __iter__(self):
         return iter(self.variables)
@@ -168,7 +252,7 @@ def __exit__(self, exception_type, exception_value, traceback):
 
 
 class ArrayWriter(object):
-    def __init__(self, lock=GLOBAL_LOCK):
+    def __init__(self, lock=HDF5_LOCK):
         self.sources = []
         self.targets = []
         self.lock = lock
@@ -178,11 +262,7 @@ def add(self, source, target):
             self.sources.append(source)
             self.targets.append(target)
         else:
-            try:
-                target[...] = source
-            except TypeError:
-                # workaround for GH: scipy/scipy#6880
-                target[:] = source
+            target[...]
= source def sync(self): if self.sources: @@ -193,9 +273,9 @@ def sync(self): class AbstractWritableDataStore(AbstractDataStore): - def __init__(self, writer=None): + def __init__(self, writer=None, lock=HDF5_LOCK): if writer is None: - writer = ArrayWriter() + writer = ArrayWriter(lock=lock) self.writer = writer def encode(self, variables, attributes): @@ -239,6 +319,9 @@ def set_variable(self, k, v): # pragma: no cover raise NotImplementedError def sync(self): + if self._isopen and self._autoclose: + # datastore will be reopened during write + self.close() self.writer.sync() def store_dataset(self, dataset): @@ -373,7 +456,8 @@ class DataStorePickleMixin(object): def __getstate__(self): state = self.__dict__.copy() - del state['ds'] + del state['_ds'] + del state['_isopen'] if self._mode == 'w': # file has already been created, don't override when restoring state['_mode'] = 'a' @@ -381,19 +465,32 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__.update(state) - self.ds = self._opener(mode=self._mode) + self._ds = None + self._isopen = False + + @property + def ds(self): + if self._ds is not None and self._isopen: + return self._ds + ds = self._opener(mode=self._mode) + self._isopen = True + return ds @contextlib.contextmanager - def ensure_open(self, autoclose): + def ensure_open(self, autoclose=None): """ Helper function to make sure datasets are closed and opened at appropriate times to avoid too many open file errors. Use requires `autoclose=True` argument to `open_mfdataset`. """ - if self._autoclose and not self._isopen: + + if autoclose is None: + autoclose = self._autoclose + + if not self._isopen: try: - self.ds = self._opener() + self._ds = self._opener() self._isopen = True yield finally: diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 1d166f05eb1..7beda03308e 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -8,7 +8,8 @@ from ..core import indexing from ..core.pycompat import OrderedDict, bytes_type, iteritems, unicode_type from ..core.utils import FrozenOrderedDict, close_on_error -from .common import DataStorePickleMixin, WritableCFDataStore, find_root +from .common import ( + HDF5_LOCK, DataStorePickleMixin, WritableCFDataStore, find_root) from .netCDF4_ import ( BaseNetCDF4Array, _encode_nc4_variable, _extract_nc4_variable_encoding, _get_datatype, _nc4_group) @@ -68,12 +69,12 @@ class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin): """ def __init__(self, filename, mode='r', format=None, group=None, - writer=None, autoclose=False): + writer=None, autoclose=False, lock=HDF5_LOCK): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') opener = functools.partial(_open_h5netcdf_group, filename, mode=mode, group=group) - self.ds = opener() + self._ds = opener() if autoclose: raise NotImplementedError('autoclose=True is not implemented ' 'for the h5netcdf backend pending ' @@ -85,7 +86,7 @@ def __init__(self, filename, mode='r', format=None, group=None, self._opener = opener self._filename = filename self._mode = mode - super(H5NetCDFStore, self).__init__(writer) + super(H5NetCDFStore, self).__init__(writer, lock=lock) def open_store_variable(self, name, var): with self.ensure_open(autoclose=False): @@ -177,7 +178,10 @@ def prepare_variable(self, name, variable, check_encoding=False, for k, v in iteritems(attrs): nc4_var.setncattr(k, v) - return nc4_var, variable.data + + target = H5NetCDFArrayWrapper(name, self) + + return target, variable.data def 
sync(self): with self.ensure_open(autoclose=True): diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 4903e9a98f2..01d1a4de5f5 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -13,8 +13,8 @@ from ..core.pycompat import PY3, OrderedDict, basestring, iteritems, suppress from ..core.utils import FrozenOrderedDict, close_on_error, is_remote_uri from .common import ( - BackendArray, DataStorePickleMixin, WritableCFDataStore, find_root, - robust_getitem) + HDF5_LOCK, BackendArray, DataStorePickleMixin, WritableCFDataStore, + find_root, robust_getitem) from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable # This lookup table maps from dtype.byteorder to a readable endian @@ -41,6 +41,11 @@ def __init__(self, variable_name, datastore): dtype = np.dtype('O') self.dtype = dtype + def __setitem__(self, key, value): + with self.datastore.ensure_open(autoclose=True): + data = self.get_array() + data[key] = value + def get_array(self): self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] @@ -231,14 +236,14 @@ class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin): """ def __init__(self, netcdf4_dataset, mode='r', writer=None, opener=None, - autoclose=False): + autoclose=False, lock=HDF5_LOCK): if autoclose and opener is None: raise ValueError('autoclose requires an opener') _disable_auto_decode_group(netcdf4_dataset) - self.ds = netcdf4_dataset + self._ds = netcdf4_dataset self._autoclose = autoclose self._isopen = True self.format = self.ds.data_model @@ -249,12 +254,12 @@ def __init__(self, netcdf4_dataset, mode='r', writer=None, opener=None, self._opener = functools.partial(opener, mode=self._mode) else: self._opener = opener - super(NetCDF4DataStore, self).__init__(writer) + super(NetCDF4DataStore, self).__init__(writer, lock=lock) @classmethod def open(cls, filename, mode='r', format='NETCDF4', group=None, writer=None, clobber=True, diskless=False, persist=False, - autoclose=False): + autoclose=False, lock=HDF5_LOCK): import netCDF4 as nc4 if (len(filename) == 88 and LooseVersion(nc4.__version__) < "1.3.1"): @@ -274,7 +279,7 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None, format=format) ds = opener() return cls(ds, mode=mode, writer=writer, opener=opener, - autoclose=autoclose) + autoclose=autoclose, lock=lock) def open_store_variable(self, name, var): with self.ensure_open(autoclose=False): @@ -399,7 +404,9 @@ def prepare_variable(self, name, variable, check_encoding=False, # OrderedDict as the input to setncatts nc4_var.setncattr(k, v) - return nc4_var, variable.data + target = NetCDF4ArrayWrapper(name, self) + + return target, variable.data def sync(self): with self.ensure_open(autoclose=True): diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 95226e453b4..3c638b6b057 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -46,14 +46,14 @@ class NioDataStore(AbstractDataStore, DataStorePickleMixin): def __init__(self, filename, mode='r', autoclose=False): import Nio opener = functools.partial(Nio.open_file, filename, mode=mode) - self.ds = opener() - # xarray provides its own support for FillValue, - # so turn off PyNIO's support for the same. - self.ds.set_option('MaskedArrayMode', 'MaskedNever') + self._ds = opener() self._autoclose = autoclose self._isopen = True self._opener = opener self._mode = mode + # xarray provides its own support for FillValue, + # so turn off PyNIO's support for the same. 
+ self.ds.set_option('MaskedArrayMode', 'MaskedNever') def open_store_variable(self, name, var): data = indexing.LazilyOuterIndexedArray(NioArrayWrapper(name, self)) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index a0765fe27bd..ee2c0fbf106 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -2,6 +2,7 @@ import functools import warnings +from distutils.version import LooseVersion from io import BytesIO import numpy as np @@ -53,6 +54,18 @@ def __getitem__(self, key): copy = self.datastore.ds.use_mmap return np.array(data, dtype=self.dtype, copy=copy) + def __setitem__(self, key, value): + with self.datastore.ensure_open(autoclose=True): + data = self.datastore.ds.variables[self.variable_name] + try: + data[key] = value + except TypeError: + if key is Ellipsis: + # workaround for GH: scipy/scipy#6880 + data[:] = value + else: + raise + def _open_scipy_netcdf(filename, mode, mmap, version): import scipy.io @@ -103,11 +116,12 @@ class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): """ def __init__(self, filename_or_obj, mode='r', format=None, group=None, - writer=None, mmap=None, autoclose=False): + writer=None, mmap=None, autoclose=False, lock=None): import scipy import scipy.io - if mode != 'r' and scipy.__version__ < '0.13': # pragma: no cover + if (mode != 'r' and + scipy.__version__ < LooseVersion('0.13')): # pragma: no cover warnings.warn('scipy %s detected; ' 'the minimal recommended version is 0.13. ' 'Older version of this library do not reliably ' @@ -129,13 +143,13 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, opener = functools.partial(_open_scipy_netcdf, filename=filename_or_obj, mode=mode, mmap=mmap, version=version) - self.ds = opener() + self._ds = opener() self._autoclose = autoclose self._isopen = True self._opener = opener self._mode = mode - super(ScipyDataStore, self).__init__(writer) + super(ScipyDataStore, self).__init__(writer, lock=lock) def open_store_variable(self, name, var): with self.ensure_open(autoclose=False): @@ -200,7 +214,10 @@ def prepare_variable(self, name, variable, check_encoding=False, for k, v in iteritems(variable.attrs): self._validate_attr_key(k) setattr(scipy_var, k, v) - return scipy_var, data + + target = ScipyArrayWrapper(name, self) + + return target, data def sync(self): with self.ensure_open(autoclose=True): @@ -221,4 +238,5 @@ def __setstate__(self, state): # seek to the start of the file so scipy can read it filename.seek(0) super(ScipyDataStore, self).__setstate__(state) - self._isopen = True + self._ds = None + self._isopen = False diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 8797e3104a1..71ce965f368 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -341,6 +341,8 @@ def prepare_variable(self, name, variable, check_encoding=False, fill_value = _ensure_valid_fill_value(attrs.pop('_FillValue', None), dtype) + if variable.encoding == {'_FillValue': None} and fill_value is None: + variable.encoding = {} encoding = _extract_zarr_variable_encoding( variable, raise_on_invalid=check_encoding) @@ -361,6 +363,9 @@ def store(self, variables, attributes, *args, **kwargs): AbstractWritableDataStore.store(self, variables, attributes, *args, **kwargs) + def sync(self): + self.writer.sync() + def open_zarr(store, group=None, synchronizer=None, auto_chunk=True, decode_cf=True, mask_and_scale=True, decode_times=True, diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 149009689e9..8c1c58e9a40 100644 --- 
a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -8,8 +8,8 @@ from .alignment import align from .merge import merge from .pycompat import OrderedDict, basestring, iteritems -from .variable import IndexVariable, Variable, as_variable from .variable import concat as concat_vars +from .variable import IndexVariable, Variable, as_variable def concat(objs, dim=None, data_vars='all', coords='different', diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 0d060069477..0ac03327494 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -1,6 +1,9 @@ """ isort:skip_file """ - +from __future__ import absolute_import, division, print_function +import os import sys +import pickle +import tempfile import pytest @@ -8,6 +11,7 @@ distributed = pytest.importorskip('distributed') # isort:skip from dask import array +from dask.distributed import Client, Lock from distributed.utils_test import cluster, gen_cluster from distributed.utils_test import loop # flake8: noqa from distributed.client import futures_of @@ -15,9 +19,11 @@ import xarray as xr from xarray.tests.test_backends import ON_WINDOWS, create_tmp_file from xarray.tests.test_dataset import create_test_data +from xarray.backends.common import HDF5_LOCK, CombinedLock from . import ( - assert_allclose, has_h5netcdf, has_netCDF4, has_scipy, requires_zarr) + assert_allclose, has_h5netcdf, has_netCDF4, has_scipy, requires_zarr, + raises_regex) # this is to stop isort throwing errors. May have been easier to just use # `isort:skip` in retrospect @@ -34,29 +40,95 @@ if has_h5netcdf: ENGINES.append('h5netcdf') +NC_FORMATS = {'netcdf4': ['NETCDF3_CLASSIC', 'NETCDF3_64BIT_OFFSET', + 'NETCDF3_64BIT_DATA', 'NETCDF4_CLASSIC', 'NETCDF4'], + 'scipy': ['NETCDF3_CLASSIC', 'NETCDF3_64BIT'], + 'h5netcdf': ['NETCDF4']} +TEST_FORMATS = ['NETCDF3_CLASSIC', 'NETCDF4_CLASSIC', 'NETCDF4'] + + +@pytest.mark.xfail(sys.platform == 'win32', + reason='https://github.com/pydata/xarray/issues/1738') +@pytest.mark.parametrize('engine', ['netcdf4']) +@pytest.mark.parametrize('autoclose', [True, False]) +@pytest.mark.parametrize('nc_format', TEST_FORMATS) +def test_dask_distributed_netcdf_roundtrip(monkeypatch, loop, + engine, autoclose, nc_format): + + monkeypatch.setenv('HDF5_USE_FILE_LOCKING', 'FALSE') + + chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} + + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + + original = create_test_data().chunk(chunks) + original.to_netcdf(filename, engine=engine, format=nc_format) + + with xr.open_dataset(filename, + chunks=chunks, + engine=engine, + autoclose=autoclose) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) + @pytest.mark.xfail(sys.platform == 'win32', reason='https://github.com/pydata/xarray/issues/1738') @pytest.mark.parametrize('engine', ENGINES) -def test_dask_distributed_netcdf_integration_test(loop, engine): - with cluster() as (s, _): - with distributed.Client(s['address'], loop=loop): - original = create_test_data() - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - original.to_netcdf(filename, engine=engine) - with xr.open_dataset( - filename, chunks=3, engine=engine) as restored: +@pytest.mark.parametrize('autoclose', [True, False]) +@pytest.mark.parametrize('nc_format', TEST_FORMATS) +def test_dask_distributed_read_netcdf_integration_test(loop, engine, 
autoclose, + nc_format): + + if engine == 'h5netcdf' and autoclose: + pytest.skip('h5netcdf does not support autoclose') + + if nc_format not in NC_FORMATS[engine]: + pytest.skip('invalid format for engine') + + chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} + + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + + original = create_test_data() + original.to_netcdf(filename, engine=engine, format=nc_format) + + with xr.open_dataset(filename, + chunks=chunks, + engine=engine, + autoclose=autoclose) as restored: assert isinstance(restored.var1.data, da.Array) computed = restored.compute() assert_allclose(original, computed) +@pytest.mark.parametrize('engine', ['h5netcdf', 'scipy']) +def test_dask_distributed_netcdf_integration_test_not_implemented(loop, engine): + chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} + + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + + original = create_test_data().chunk(chunks) + + with raises_regex(NotImplementedError, 'distributed'): + original.to_netcdf(filename, engine=engine) + + @requires_zarr def test_dask_distributed_zarr_integration_test(loop): - with cluster() as (s, _): - with distributed.Client(s['address'], loop=loop): - original = create_test_data() - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: + chunks = {'dim1': 4, 'dim2': 3, 'dim3': 5} + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + original = create_test_data().chunk(chunks) + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS, + suffix='.zarr') as filename: original.to_zarr(filename) with xr.open_zarr(filename) as restored: assert isinstance(restored.var1.data, da.Array) @@ -92,3 +164,25 @@ def test_async(c, s, a, b): assert_allclose(x + 10, w) assert s.tasks + + +def test_hdf5_lock(): + assert isinstance(HDF5_LOCK, dask.utils.SerializableLock) + + +@gen_cluster(client=True) +def test_serializable_locks(c, s, a, b): + def f(x, lock=None): + with lock: + return x + 1 + + # note, the creation of Lock needs to be done inside a cluster + for lock in [HDF5_LOCK, Lock(), Lock('filename.nc'), + CombinedLock([HDF5_LOCK]), + CombinedLock([HDF5_LOCK, Lock('filename.nc')])]: + + futures = c.map(f, list(range(10)), lock=lock) + yield c.gather(futures) + + lock2 = pickle.loads(pickle.dumps(lock)) + assert type(lock) == type(lock2)