Skip to content

Feature/pickle rasterio #2131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 7, 2018
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ Enhancements
Bug fixes
~~~~~~~~~

- Fixed a bug in ``rasterio`` backend which prevented use with ``distributed``.
The ``rasterio`` backend now returns pickleable objects (:issue:`2021`).

.. _whats-new.0.10.6:

v0.10.6 (31 May 2018)
Expand Down Expand Up @@ -220,6 +223,7 @@ Bug fixes
By `Deepak Cherian <https://github.com/dcherian>`_.
- Colorbar limits are now determined by excluding ±Infs too.
By `Deepak Cherian <https://github.com/dcherian>`_.
By `Joe Hamman <https://github.com/jhamman>`_.
- Fixed ``to_iris`` to maintain lazy dask array after conversion (:issue:`2046`).
By `Alex Hilson <https://github.com/AlexHilson>`_ and `Stephan Hoyer <https://github.com/shoyer>`_.

Expand Down
28 changes: 28 additions & 0 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import traceback
import warnings
from collections import Mapping, OrderedDict
from functools import partial

import numpy as np

Expand Down Expand Up @@ -507,3 +508,30 @@ def assert_open(self):
if not self._isopen:
raise AssertionError('internal failure: file must be open '
'if `autoclose=True` is used.')


class PickleByReconstructionWrapper(object):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to add a unit test verifies that this works properly independently of any concrete datastore.

Maybe something simple with open()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 8b286c0


def __init__(self, opener, file, mode='r', **kwargs):
self.opener = partial(opener, file, mode=mode, **kwargs)
self.mode = mode
self._ds = None

@property
def value(self):
self._ds = self.opener()
return self._ds

def __getstate__(self):
state = self.__dict__.copy()
del state['_ds']
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we close the file here? state['_ds'].close()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be used after it is pickled.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah OK, maybe I misunderstood. I thought the purpose of this wrapper was to dereference (and close) the file before 'dump()' and re-open it after pickle.load().

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The basic idea is that most open file handles can't be pickled so we need to provide a mechanism to remove the existing handle and generate a new one in the dump/load steps in the pickling. In most cases, we do want to keep the original file open.

if self.mode == 'w':
# file has already been created, don't override when restoring
state['mode'] = 'a'
return state

def __setstate__(self, state):
self.__dict__.update(state)

def close(self):
self._ds.close()
56 changes: 29 additions & 27 deletions xarray/backends/rasterio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .. import DataArray
from ..core import indexing
from ..core.utils import is_scalar
from .common import BackendArray
from .common import BackendArray, PickleByReconstructionWrapper

try:
from dask.utils import SerializableLock as Lock
Expand All @@ -25,15 +25,15 @@
class RasterioArrayWrapper(BackendArray):
"""A wrapper around rasterio dataset objects"""

def __init__(self, rasterio_ds):
self.rasterio_ds = rasterio_ds
self._shape = (rasterio_ds.count, rasterio_ds.height,
rasterio_ds.width)
def __init__(self, riods):
self.riods = riods
self._shape = (riods.value.count, riods.value.height,
riods.value.width)
self._ndims = len(self.shape)

@property
def dtype(self):
dtypes = self.rasterio_ds.dtypes
dtypes = self.riods.value.dtypes
if not np.all(np.asarray(dtypes) == dtypes[0]):
raise ValueError('All bands should have the same dtype')
return np.dtype(dtypes[0])
Expand Down Expand Up @@ -105,7 +105,7 @@ def _get_indexer(self, key):
def __getitem__(self, key):
band_key, window, squeeze_axis, np_inds = self._get_indexer(key)

out = self.rasterio_ds.read(band_key, window=tuple(window))
out = self.riods.value.read(band_key, window=tuple(window))
if squeeze_axis:
out = np.squeeze(out, axis=squeeze_axis)
return indexing.NumpyIndexingAdapter(out)[np_inds]
Expand Down Expand Up @@ -194,28 +194,29 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
"""

import rasterio
riods = rasterio.open(filename, mode='r')

riods = PickleByReconstructionWrapper(rasterio.open, filename, mode='r')

if cache is None:
cache = chunks is None

coords = OrderedDict()

# Get bands
if riods.count < 1:
if riods.value.count < 1:
raise ValueError('Unknown dims')
coords['band'] = np.asarray(riods.indexes)
coords['band'] = np.asarray(riods.value.indexes)

# Get coordinates
if LooseVersion(rasterio.__version__) < '1.0':
transform = riods.affine
transform = riods.value.affine
else:
transform = riods.transform
transform = riods.value.transform
if transform.is_rectilinear:
# 1d coordinates
parse = True if parse_coordinates is None else parse_coordinates
if parse:
nx, ny = riods.width, riods.height
nx, ny = riods.value.width, riods.value.height
# xarray coordinates are pixel centered
x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform
_, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform
Expand All @@ -238,41 +239,42 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
# For serialization store as tuple of 6 floats, the last row being
# always (0, 0, 1) per definition (see https://github.com/sgillies/affine)
attrs['transform'] = tuple(transform)[:6]
if hasattr(riods, 'crs') and riods.crs:
if hasattr(riods.value, 'crs') and riods.value.crs:
# CRS is a dict-like object specific to rasterio
# If CRS is not None, we convert it back to a PROJ4 string using
# rasterio itself
attrs['crs'] = riods.crs.to_string()
if hasattr(riods, 'res'):
attrs['crs'] = riods.value.crs.to_string()
if hasattr(riods.value, 'res'):
# (width, height) tuple of pixels in units of CRS
attrs['res'] = riods.res
if hasattr(riods, 'is_tiled'):
attrs['res'] = riods.value.res
if hasattr(riods.value, 'is_tiled'):
# Is the TIF tiled? (bool)
# We cast it to an int for netCDF compatibility
attrs['is_tiled'] = np.uint8(riods.is_tiled)
attrs['is_tiled'] = np.uint8(riods.value.is_tiled)
with warnings.catch_warnings():
# casting riods.transform to a tuple makes this future proof
# casting riods.value.transform to a tuple makes this future proof
warnings.simplefilter('ignore', FutureWarning)
if hasattr(riods, 'transform'):
if hasattr(riods.value, 'transform'):
# Affine transformation matrix (tuple of floats)
# Describes coefficients mapping pixel coordinates to CRS
attrs['transform'] = tuple(riods.transform)
if hasattr(riods, 'nodatavals'):
attrs['transform'] = tuple(riods.value.transform)
if hasattr(riods.value, 'nodatavals'):
# The nodata values for the raster bands
attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval
for nodataval in riods.nodatavals])
for nodataval in riods.value.nodatavals])

# Parse extra metadata from tags, if supported
parsers = {'ENVI': _parse_envi}

driver = riods.driver
driver = riods.value.driver
if driver in parsers:
meta = parsers[driver](riods.tags(ns=driver))
meta = parsers[driver](riods.value.tags(ns=driver))

for k, v in meta.items():
# Add values as coordinates if they match the band count,
# as attributes otherwise
if isinstance(v, (list, np.ndarray)) and len(v) == riods.count:
if (isinstance(v, (list, np.ndarray)) and
len(v) == riods.value.count):
coords[k] = ('band', np.asarray(v))
else:
attrs[k] = v
Expand Down
40 changes: 38 additions & 2 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from xarray import (
DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset,
save_mfdataset)
from xarray.backends.common import robust_getitem
from xarray.backends.common import (robust_getitem,
PickleByReconstructionWrapper)
from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding
from xarray.backends.pydap_ import PydapDataStore
from xarray.core import indexing
Expand Down Expand Up @@ -2724,7 +2725,8 @@ def create_tmp_geotiff(nx=4, ny=3, nz=3,
# yields a temporary geotiff file and a corresponding expected DataArray
import rasterio
from rasterio.transform import from_origin
with create_tmp_file(suffix='.tif') as tmp_file:
with create_tmp_file(suffix='.tif',
allow_cleanup_failure=ON_WINDOWS) as tmp_file:
# allow 2d or 3d shapes
if nz == 1:
data_shape = ny, nx
Expand Down Expand Up @@ -2996,6 +2998,14 @@ def test_chunks(self):
ex = expected.sel(band=1).mean(dim='x')
assert_allclose(ac, ex)

def test_pickle_rasterio(self):
# regression test for https://github.com/pydata/xarray/issues/2121
with create_tmp_geotiff() as (tmp_file, expected):
with xr.open_rasterio(tmp_file) as rioda:
temp = pickle.dumps(rioda)
with pickle.loads(temp) as actual:
assert_equal(actual, rioda)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to also add an integration test reading rasterio data with dask.distributed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 55a3abc.


def test_ENVI_tags(self):
rasterio = pytest.importorskip('rasterio', minversion='1.0a')
from rasterio.transform import from_origin
Expand Down Expand Up @@ -3260,3 +3270,29 @@ def test_dataarray_to_netcdf_no_name_pathlib(self):

with open_dataarray(tmp) as loaded_da:
assert_identical(original_da, loaded_da)


def test_pickle_reconstructor():

lines = ['foo bar spam eggs']

with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp:
with open(tmp, 'w') as f:
f.writelines(lines)

obj = PickleByReconstructionWrapper(open, tmp)

assert obj.value.readlines() == lines

p_obj = pickle.dumps(obj)
obj.value.close() # for windows
obj2 = pickle.loads(p_obj)

assert obj2.value.readlines() == lines

# roundtrip again to make sure we can fully restore the state
p_obj2 = pickle.dumps(obj2)
obj2.value.close() # for windows
obj3 = pickle.loads(p_obj2)

assert obj3.value.readlines() == lines
18 changes: 15 additions & 3 deletions xarray/tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
from distributed.client import futures_of

import xarray as xr
from xarray.tests.test_backends import ON_WINDOWS, create_tmp_file
from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file,
create_tmp_geotiff)
from xarray.tests.test_dataset import create_test_data
from xarray.backends.common import HDF5_LOCK, CombinedLock

from . import (
assert_allclose, has_h5netcdf, has_netCDF4, has_scipy, requires_zarr,
raises_regex)
assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy,
requires_zarr, raises_regex)

# this is to stop isort throwing errors. May have been easier to just use
# `isort:skip` in retrospect
Expand Down Expand Up @@ -136,6 +137,17 @@ def test_dask_distributed_zarr_integration_test(loop):
assert_allclose(original, computed)


@requires_rasterio
def test_dask_distributed_rasterio_integration_test(loop):
with create_tmp_geotiff() as (tmp_file, expected):
with cluster() as (s, [a, b]):
with Client(s['address'], loop=loop) as c:
da_tiff = xr.open_rasterio(tmp_file, chunks={'band': 1})
assert isinstance(da_tiff.data, da.Array)
actual = da_tiff.compute()
assert_allclose(actual, expected)


@pytest.mark.skipif(distributed.__version__ <= '1.19.3',
reason='Need recent distributed version to clean up get')
@gen_cluster(client=True, timeout=None)
Expand Down