Skip to content

Adding mask to open_rasterio #2618

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions xarray/backends/rasterio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class RasterioArrayWrapper(BackendArray):
"""A wrapper around rasterio dataset objects"""

def __init__(self, manager, lock, vrt_params=None):
def __init__(self, manager, lock, vrt_params=None, masked=False):
from rasterio.vrt import WarpedVRT
self.manager = manager
self.lock = lock
Expand All @@ -39,6 +39,7 @@ def __init__(self, manager, lock, vrt_params=None):
if not np.all(np.asarray(dtypes) == dtypes[0]):
raise ValueError('All bands should have the same dtype')
self._dtype = np.dtype(dtypes[0])
Copy link
Contributor

@snowman2 snowman2 Jul 24, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self._dtype = np.dtype(dtypes[0])
self._dtype = np.float64 if self.masked else np.dtype(dtypes[0])

self.masked = masked

@property
def dtype(self):
Expand Down Expand Up @@ -119,7 +120,9 @@ def _getitem(self, key):
riods = self.manager.acquire(needs_lock=False)
if self.vrt_params is not None:
riods = WarpedVRT(riods, **self.vrt_params)
out = riods.read(band_key, window=window)
out = riods.read(band_key, window=window, masked=self.masked)
if self.masked:
out = np.ma.filled(out, np.nan)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
out = np.ma.filled(out, np.nan)
out = np.ma.filled(out.astype(np.float64), np.nan)


if squeeze_axis:
out = np.squeeze(out, axis=squeeze_axis)
Expand Down Expand Up @@ -162,7 +165,7 @@ def default(s):


def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
lock=None):
lock=None, masked=False):
"""Open a file with rasterio (experimental).

This should work with any file that rasterio can open (most often:
Expand Down Expand Up @@ -205,6 +208,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
:py:func:`dask.array.from_array`. By default, a global lock is
used to avoid issues with concurrent access to the same file when using
dask's multithreaded backend.
masked : bool, optional
If True, uses nodatavals to set values to NaN. Defaults to False.

Returns
-------
Expand Down Expand Up @@ -275,6 +280,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,

# Attributes
attrs = dict()
encoding = dict()
# Affine transformation matrix (always available)
# This describes coefficients mapping pixel coordinates to CRS
# For serialization store as tuple of 6 floats, the last row being
Expand All @@ -298,9 +304,13 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
attrs['is_tiled'] = np.uint8(riods.is_tiled)
if hasattr(riods, 'nodatavals'):
# The nodata values for the raster bands
attrs['nodatavals'] = tuple(
nodatavals = tuple(
np.nan if nodataval is None else nodataval
for nodataval in riods.nodatavals)
if masked:
encoding['nodatavals'] = nodatavals
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thing setting the encoding should inline with the other if statements and should be something like:

if mask and riods.nodata is not None:
    encoding['_FillValue'] = riods.nodata

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shoyer is this what you intended or did you want to keep it as @jsignell modified it to be?

else:
attrs['nodatavals'] = nodatavals

# Parse extra metadata from tags, if supported
parsers = {'ENVI': _parse_envi}
Expand All @@ -319,15 +329,15 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
attrs[k] = v

data = indexing.LazilyOuterIndexedArray(
RasterioArrayWrapper(manager, lock, vrt_params))
RasterioArrayWrapper(manager, lock, vrt_params, masked=masked))

# this lets you write arrays loaded with rasterio
data = indexing.CopyOnWriteArray(data)
if cache and chunks is None:
data = indexing.MemoryCachedArray(data)

result = DataArray(data=data, dims=('band', 'y', 'x'),
coords=coords, attrs=attrs)
coords=coords, attrs=attrs, encoding=encoding)
Copy link
Contributor

@snowman2 snowman2 Jul 24, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
coords=coords, attrs=attrs, encoding=encoding)
coords=coords, attrs=attrs)
result.encoding = encoding


if chunks is not None:
from dask.base import tokenize
Expand Down
25 changes: 25 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3130,6 +3130,31 @@ def test_platecarree(self):
np.testing.assert_array_equal(rioda.attrs['nodatavals'],
[-9765.])

def test_masked(self):
with create_tmp_geotiff(8, 10, 1, open_kwargs={'nodata': 3}) \
as (tmp_file, expected):
with xr.open_rasterio(tmp_file, masked=True) as rioda:
assert_allclose(rioda, expected.where(expected != 3))
np.testing.assert_array_equal(rioda.encoding['nodatavals'], [3.])
assert 'nodatavals' not in rioda.attrs.keys()
assert 3 not in rioda.data

# now write the dataarray to netcdf and read it back.
with create_tmp_file() as tmp_nc_file:
rioda.to_dataset(name='rio').to_netcdf(tmp_nc_file)
with xr.open_dataset(tmp_nc_file) as riods:
da = riods.rio
assert_allclose(da, expected.where(expected != 3))

def test_masked_is_false_by_default(self):
with create_tmp_geotiff(8, 10, 1, open_kwargs={'nodata': 3}) \
as (tmp_file, expected):
with xr.open_rasterio(tmp_file) as rioda:
assert_allclose(rioda, expected)
np.testing.assert_array_equal(rioda.attrs['nodatavals'], [3.])
assert 'nodatavals' not in rioda.encoding.keys()
assert 3 in rioda.data

def test_notransform(self):
# regression test for https://github.com/pydata/xarray/issues/1686
import rasterio
Expand Down