Commit 21a9f3d

Author: Joe Hamman

Feature/pickle rasterio (#2131)

* add regression test
* add PickleByReconstructionWrapper
* docs
* load in context manager
* add distributed integration test
* add test_pickle_reconstructor
* drop lazy opening/caching and use partial function for open
* stop using clever getattr hack
* allow_cleanup_failure=ON_WINDOWS in tests for windows
* whats new fix
* fix bug in multiple pickles
* fix for windows

1 parent bc52f8a, commit 21a9f3d

File tree: 5 files changed, +114 −32 lines

  doc/whats-new.rst
  xarray/backends/common.py
  xarray/backends/rasterio_.py
  xarray/tests/test_backends.py
  xarray/tests/test_distributed.py

doc/whats-new.rst (+4)

@@ -41,6 +41,9 @@ Enhancements
 Bug fixes
 ~~~~~~~~~
 
+- Fixed a bug in ``rasterio`` backend which prevented use with ``distributed``.
+  The ``rasterio`` backend now returns pickleable objects (:issue:`2021`).
+
 .. _whats-new.0.10.6:
 
 v0.10.6 (31 May 2018)
@@ -220,6 +223,7 @@ Bug fixes
   By `Deepak Cherian <https://github.com/dcherian>`_.
 - Colorbar limits are now determined by excluding ±Infs too.
   By `Deepak Cherian <https://github.com/dcherian>`_.
+  By `Joe Hamman <https://github.com/jhamman>`_.
 - Fixed ``to_iris`` to maintain lazy dask array after conversion (:issue:`2046`).
   By `Alex Hilson <https://github.com/AlexHilson>`_ and `Stephan Hoyer <https://github.com/shoyer>`_.
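
The user-visible effect of the whats-new entry above is that a DataArray opened through the rasterio backend can now round-trip through pickle, which is what distributed needs in order to ship work to remote workers. A minimal sketch of that behaviour (the path 'example.tif' is only a placeholder for a real GeoTIFF on disk):

    import pickle

    import xarray as xr

    # open a GeoTIFF through the rasterio backend (placeholder path)
    rioda = xr.open_rasterio('example.tif')

    # previously this raised because the open rasterio dataset handle could not
    # be serialized; the backend now stores only what it needs to reopen the file
    restored = pickle.loads(pickle.dumps(rioda))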

xarray/backends/common.py (+28)

@@ -8,6 +8,7 @@
 import traceback
 import warnings
 from collections import Mapping, OrderedDict
+from functools import partial
 
 import numpy as np
 
@@ -507,3 +508,30 @@ def assert_open(self):
         if not self._isopen:
             raise AssertionError('internal failure: file must be open '
                                  'if `autoclose=True` is used.')
+
+
+class PickleByReconstructionWrapper(object):
+
+    def __init__(self, opener, file, mode='r', **kwargs):
+        self.opener = partial(opener, file, mode=mode, **kwargs)
+        self.mode = mode
+        self._ds = None
+
+    @property
+    def value(self):
+        self._ds = self.opener()
+        return self._ds
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state['_ds']
+        if self.mode == 'w':
+            # file has already been created, don't override when restoring
+            state['mode'] = 'a'
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+
+    def close(self):
+        self._ds.close()
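
The wrapper never pickles an open file handle. It keeps a functools.partial of the opener and its arguments, __getstate__ drops the live handle (and downgrades mode 'w' to 'a' so an already-created file is not clobbered on restore), and every access to .value simply calls the opener again. A minimal usage sketch with the built-in open, mirroring the test_pickle_reconstructor test added further down; the file path is illustrative:

    import pickle

    from xarray.backends.common import PickleByReconstructionWrapper

    # create a small file to wrap (illustrative path)
    with open('notes.txt', 'w') as f:
        f.write('hello\n')

    obj = PickleByReconstructionWrapper(open, 'notes.txt')

    # each .value access calls the stored partial, i.e. reopens the file
    print(obj.value.readline())

    # only the partial and the mode travel through pickle; no handle involved
    payload = pickle.dumps(obj)
    obj.close()  # closes the handle from the most recent .value access

    obj2 = pickle.loads(payload)
    print(obj2.value.readline())  # the file is reopened on demand
    obj2.close()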

xarray/backends/rasterio_.py (+29 −27)

@@ -8,7 +8,7 @@
 from .. import DataArray
 from ..core import indexing
 from ..core.utils import is_scalar
-from .common import BackendArray
+from .common import BackendArray, PickleByReconstructionWrapper
 
 try:
     from dask.utils import SerializableLock as Lock
@@ -25,15 +25,15 @@
 class RasterioArrayWrapper(BackendArray):
     """A wrapper around rasterio dataset objects"""
 
-    def __init__(self, rasterio_ds):
-        self.rasterio_ds = rasterio_ds
-        self._shape = (rasterio_ds.count, rasterio_ds.height,
-                       rasterio_ds.width)
+    def __init__(self, riods):
+        self.riods = riods
+        self._shape = (riods.value.count, riods.value.height,
+                       riods.value.width)
         self._ndims = len(self.shape)
 
     @property
     def dtype(self):
-        dtypes = self.rasterio_ds.dtypes
+        dtypes = self.riods.value.dtypes
         if not np.all(np.asarray(dtypes) == dtypes[0]):
             raise ValueError('All bands should have the same dtype')
         return np.dtype(dtypes[0])
@@ -105,7 +105,7 @@ def _get_indexer(self, key):
     def __getitem__(self, key):
         band_key, window, squeeze_axis, np_inds = self._get_indexer(key)
 
-        out = self.rasterio_ds.read(band_key, window=tuple(window))
+        out = self.riods.value.read(band_key, window=tuple(window))
         if squeeze_axis:
             out = np.squeeze(out, axis=squeeze_axis)
         return indexing.NumpyIndexingAdapter(out)[np_inds]
@@ -194,28 +194,29 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
     """
 
     import rasterio
-    riods = rasterio.open(filename, mode='r')
+
+    riods = PickleByReconstructionWrapper(rasterio.open, filename, mode='r')
 
     if cache is None:
         cache = chunks is None
 
     coords = OrderedDict()
 
     # Get bands
-    if riods.count < 1:
+    if riods.value.count < 1:
         raise ValueError('Unknown dims')
-    coords['band'] = np.asarray(riods.indexes)
+    coords['band'] = np.asarray(riods.value.indexes)
 
     # Get coordinates
     if LooseVersion(rasterio.__version__) < '1.0':
-        transform = riods.affine
+        transform = riods.value.affine
     else:
-        transform = riods.transform
+        transform = riods.value.transform
     if transform.is_rectilinear:
         # 1d coordinates
         parse = True if parse_coordinates is None else parse_coordinates
         if parse:
-            nx, ny = riods.width, riods.height
+            nx, ny = riods.value.width, riods.value.height
             # xarray coordinates are pixel centered
             x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform
             _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform
@@ -238,41 +239,42 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
     # For serialization store as tuple of 6 floats, the last row being
     # always (0, 0, 1) per definition (see https://github.com/sgillies/affine)
     attrs['transform'] = tuple(transform)[:6]
-    if hasattr(riods, 'crs') and riods.crs:
+    if hasattr(riods.value, 'crs') and riods.value.crs:
         # CRS is a dict-like object specific to rasterio
         # If CRS is not None, we convert it back to a PROJ4 string using
         # rasterio itself
-        attrs['crs'] = riods.crs.to_string()
-    if hasattr(riods, 'res'):
+        attrs['crs'] = riods.value.crs.to_string()
+    if hasattr(riods.value, 'res'):
         # (width, height) tuple of pixels in units of CRS
-        attrs['res'] = riods.res
-    if hasattr(riods, 'is_tiled'):
+        attrs['res'] = riods.value.res
+    if hasattr(riods.value, 'is_tiled'):
         # Is the TIF tiled? (bool)
         # We cast it to an int for netCDF compatibility
-        attrs['is_tiled'] = np.uint8(riods.is_tiled)
+        attrs['is_tiled'] = np.uint8(riods.value.is_tiled)
     with warnings.catch_warnings():
-        # casting riods.transform to a tuple makes this future proof
+        # casting riods.value.transform to a tuple makes this future proof
         warnings.simplefilter('ignore', FutureWarning)
-        if hasattr(riods, 'transform'):
+        if hasattr(riods.value, 'transform'):
             # Affine transformation matrix (tuple of floats)
             # Describes coefficients mapping pixel coordinates to CRS
-            attrs['transform'] = tuple(riods.transform)
-    if hasattr(riods, 'nodatavals'):
+            attrs['transform'] = tuple(riods.value.transform)
+    if hasattr(riods.value, 'nodatavals'):
         # The nodata values for the raster bands
         attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval
-                                     for nodataval in riods.nodatavals])
+                                     for nodataval in riods.value.nodatavals])
 
     # Parse extra metadata from tags, if supported
     parsers = {'ENVI': _parse_envi}
 
-    driver = riods.driver
+    driver = riods.value.driver
     if driver in parsers:
-        meta = parsers[driver](riods.tags(ns=driver))
+        meta = parsers[driver](riods.value.tags(ns=driver))
 
         for k, v in meta.items():
             # Add values as coordinates if they match the band count,
             # as attributes otherwise
-            if isinstance(v, (list, np.ndarray)) and len(v) == riods.count:
+            if (isinstance(v, (list, np.ndarray)) and
+                    len(v) == riods.value.count):
                 coords[k] = ('band', np.asarray(v))
             else:
                 attrs[k] = v
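
Every attribute that was previously read from the raw rasterio dataset now goes through riods.value, and the value property calls the stored opener on each access rather than keeping a cached handle, so the wrapper (and anything holding it) stays pickleable. A small sketch of that access pattern, assuming rasterio is installed and using a placeholder path:

    import rasterio

    from xarray.backends.common import PickleByReconstructionWrapper

    # this mirrors how open_rasterio now constructs the dataset wrapper;
    # 'example.tif' is a placeholder for a real GeoTIFF on disk
    riods = PickleByReconstructionWrapper(rasterio.open, 'example.tif', mode='r')

    # each .value access reopens the dataset via the stored partial
    print(riods.value.count, riods.value.width, riods.value.height)

    riods.close()  # closes the handle from the most recent .value access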

xarray/tests/test_backends.py (+38 −2)

@@ -19,7 +19,8 @@
 from xarray import (
     DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset,
     save_mfdataset)
-from xarray.backends.common import robust_getitem
+from xarray.backends.common import (robust_getitem,
+                                    PickleByReconstructionWrapper)
 from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding
 from xarray.backends.pydap_ import PydapDataStore
 from xarray.core import indexing
@@ -2724,7 +2725,8 @@ def create_tmp_geotiff(nx=4, ny=3, nz=3,
     # yields a temporary geotiff file and a corresponding expected DataArray
     import rasterio
     from rasterio.transform import from_origin
-    with create_tmp_file(suffix='.tif') as tmp_file:
+    with create_tmp_file(suffix='.tif',
+                         allow_cleanup_failure=ON_WINDOWS) as tmp_file:
         # allow 2d or 3d shapes
         if nz == 1:
             data_shape = ny, nx
@@ -2996,6 +2998,14 @@ def test_chunks(self):
                 ex = expected.sel(band=1).mean(dim='x')
                 assert_allclose(ac, ex)
 
+    def test_pickle_rasterio(self):
+        # regression test for https://github.com/pydata/xarray/issues/2121
+        with create_tmp_geotiff() as (tmp_file, expected):
+            with xr.open_rasterio(tmp_file) as rioda:
+                temp = pickle.dumps(rioda)
+                with pickle.loads(temp) as actual:
+                    assert_equal(actual, rioda)
+
     def test_ENVI_tags(self):
         rasterio = pytest.importorskip('rasterio', minversion='1.0a')
         from rasterio.transform import from_origin
@@ -3260,3 +3270,29 @@ def test_dataarray_to_netcdf_no_name_pathlib(self):
 
         with open_dataarray(tmp) as loaded_da:
             assert_identical(original_da, loaded_da)
+
+
+def test_pickle_reconstructor():
+
+    lines = ['foo bar spam eggs']
+
+    with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp:
+        with open(tmp, 'w') as f:
+            f.writelines(lines)
+
+        obj = PickleByReconstructionWrapper(open, tmp)
+
+        assert obj.value.readlines() == lines
+
+        p_obj = pickle.dumps(obj)
+        obj.value.close()  # for windows
+        obj2 = pickle.loads(p_obj)
+
+        assert obj2.value.readlines() == lines
+
+        # roundtrip again to make sure we can fully restore the state
+        p_obj2 = pickle.dumps(obj2)
+        obj2.value.close()  # for windows
+        obj3 = pickle.loads(p_obj2)
+
+        assert obj3.value.readlines() == lines

xarray/tests/test_distributed.py (+15 −3)

@@ -17,13 +17,14 @@
 from distributed.client import futures_of
 
 import xarray as xr
-from xarray.tests.test_backends import ON_WINDOWS, create_tmp_file
+from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file,
+                                        create_tmp_geotiff)
 from xarray.tests.test_dataset import create_test_data
 from xarray.backends.common import HDF5_LOCK, CombinedLock
 
 from . import (
-    assert_allclose, has_h5netcdf, has_netCDF4, has_scipy, requires_zarr,
-    raises_regex)
+    assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy,
+    requires_zarr, raises_regex)
 
 # this is to stop isort throwing errors. May have been easier to just use
 # `isort:skip` in retrospect
@@ -136,6 +137,17 @@ def test_dask_distributed_zarr_integration_test(loop):
                     assert_allclose(original, computed)
 
 
+@requires_rasterio
+def test_dask_distributed_rasterio_integration_test(loop):
+    with create_tmp_geotiff() as (tmp_file, expected):
+        with cluster() as (s, [a, b]):
+            with Client(s['address'], loop=loop) as c:
+                da_tiff = xr.open_rasterio(tmp_file, chunks={'band': 1})
+                assert isinstance(da_tiff.data, da.Array)
+                actual = da_tiff.compute()
+                assert_allclose(actual, expected)
+
+
 @pytest.mark.skipif(distributed.__version__ <= '1.19.3',
                     reason='Need recent distributed version to clean up get')
 @gen_cluster(client=True, timeout=None)
