# Patch summary (free-text preamble placed ahead of the first file header).
#
# Files touched: xarray/backends/rasterio_.py and xarray/tests/test_backends.py.
#
# xarray/backends/rasterio_.py
#   * RasterioArrayWrapper.__init__ gains a `masked=False` parameter and stores
#     it as `self.masked`.
#   * RasterioArrayWrapper._getitem forwards `masked=self.masked` to
#     `riods.read(...)`; when `self.masked` is true the result is passed
#     through `np.ma.filled(out, np.nan)` before the optional squeeze.
#   * open_rasterio gains a `masked=False` keyword, documents it, and forwards
#     it to RasterioArrayWrapper.
#   * open_rasterio builds a new `encoding` dict and passes it to DataArray.
#     The nodatavals tuple is stored in `encoding['nodatavals']` when `masked`
#     is true, otherwise in `attrs['nodatavals']` as before.
#
# xarray/tests/test_backends.py
#   * test_masked: opens a temporary GeoTIFF created with nodata=3 using
#     masked=True, checks the value 3 is absent from the data, checks that
#     nodatavals lives in `encoding` and not in `attrs`, then round-trips the
#     array through netCDF.
#   * test_masked_is_false_by_default: same fixture without `masked`; checks
#     the data is unchanged and nodatavals stays in `attrs`.
#
# NOTE(review): the context line `self._dtype = np.dtype(dtypes[0])` is left
#   unchanged, while masked reads are filled with NaN. For integer band dtypes
#   NaN is presumably not representable, and the wrapper's dtype would then
#   disagree with any float data returned. The tests shown do not reveal which
#   dtype create_tmp_geotiff writes -- TODO confirm behaviour of masked=True on
#   integer rasters.
# NOTE(review): in this copy several hunks share one physical line, so the
#   original line breaks appear to have been flattened to spaces -- TODO
#   regenerate the patch from the branch before trying to apply it.
diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 2f00702f854..0261961d65b 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -23,7 +23,7 @@ class RasterioArrayWrapper(BackendArray): """A wrapper around rasterio dataset objects""" - def __init__(self, manager, lock, vrt_params=None): + def __init__(self, manager, lock, vrt_params=None, masked=False): from rasterio.vrt import WarpedVRT self.manager = manager self.lock = lock @@ -39,6 +39,7 @@ def __init__(self, manager, lock, vrt_params=None): if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError('All bands should have the same dtype') self._dtype = np.dtype(dtypes[0]) + self.masked = masked @property def dtype(self): @@ -119,7 +120,9 @@ def _getitem(self, key): riods = self.manager.acquire(needs_lock=False) if self.vrt_params is not None: riods = WarpedVRT(riods, **self.vrt_params) - out = riods.read(band_key, window=window) + out = riods.read(band_key, window=window, masked=self.masked) + if self.masked: + out = np.ma.filled(out, np.nan) if squeeze_axis: out = np.squeeze(out, axis=squeeze_axis) @@ -162,7 +165,7 @@ def default(s): def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, - lock=None): + lock=None, masked=False): """Open a file with rasterio (experimental). This should work with any file that rasterio can open (most often: @@ -205,6 +208,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, :py:func:`dask.array.from_array`. By default, a global lock is used to avoid issues with concurrent access to the same file when using dask's multithreaded backend. + masked : bool, optional + If True, uses nodatavals to set values to NaN. Defaults to False. 
Returns ------- @@ -275,6 +280,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # Attributes attrs = dict() + encoding = dict() # Affine transformation matrix (always available) # This describes coefficients mapping pixel coordinates to CRS # For serialization store as tuple of 6 floats, the last row being @@ -298,9 +304,13 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, attrs['is_tiled'] = np.uint8(riods.is_tiled) if hasattr(riods, 'nodatavals'): # The nodata values for the raster bands - attrs['nodatavals'] = tuple( + nodatavals = tuple( np.nan if nodataval is None else nodataval for nodataval in riods.nodatavals) + if masked: + encoding['nodatavals'] = nodatavals + else: + attrs['nodatavals'] = nodatavals # Parse extra metadata from tags, if supported parsers = {'ENVI': _parse_envi} @@ -319,7 +329,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, attrs[k] = v data = indexing.LazilyOuterIndexedArray( - RasterioArrayWrapper(manager, lock, vrt_params)) + RasterioArrayWrapper(manager, lock, vrt_params, masked=masked)) # this lets you write arrays loaded with rasterio data = indexing.CopyOnWriteArray(data) @@ -327,7 +337,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, data = indexing.MemoryCachedArray(data) result = DataArray(data=data, dims=('band', 'y', 'x'), - coords=coords, attrs=attrs) + coords=coords, attrs=attrs, encoding=encoding) if chunks is not None: from dask.base import tokenize diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 5e28ff46665..6b4e67575fa 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3130,6 +3130,31 @@ def test_platecarree(self): np.testing.assert_array_equal(rioda.attrs['nodatavals'], [-9765.]) + def test_masked(self): + with create_tmp_geotiff(8, 10, 1, open_kwargs={'nodata': 3}) \ + as (tmp_file, expected): + with xr.open_rasterio(tmp_file, 
masked=True) as rioda: + assert_allclose(rioda, expected.where(expected != 3)) + np.testing.assert_array_equal(rioda.encoding['nodatavals'], [3.]) + assert 'nodatavals' not in rioda.attrs.keys() + assert 3 not in rioda.data + + # now write the dataarray to netcdf and read it back. + with create_tmp_file() as tmp_nc_file: + rioda.to_dataset(name='rio').to_netcdf(tmp_nc_file) + with xr.open_dataset(tmp_nc_file) as riods: + da = riods.rio + assert_allclose(da, expected.where(expected != 3)) + + def test_masked_is_false_by_default(self): + with create_tmp_geotiff(8, 10, 1, open_kwargs={'nodata': 3}) \ + as (tmp_file, expected): + with xr.open_rasterio(tmp_file) as rioda: + assert_allclose(rioda, expected) + np.testing.assert_array_equal(rioda.attrs['nodatavals'], [3.]) + assert 'nodatavals' not in rioda.encoding.keys() + assert 3 in rioda.data + def test_notransform(self): # regression test for https://github.com/pydata/xarray/issues/1686 import rasterio