8
8
from .. import DataArray
9
9
from ..core import indexing
10
10
from ..core .utils import is_scalar
11
- from .common import BackendArray
11
+ from .common import BackendArray , PickleByReconstructionWrapper
12
12
13
13
try :
14
14
from dask .utils import SerializableLock as Lock
25
25
class RasterioArrayWrapper (BackendArray ):
26
26
"""A wrapper around rasterio dataset objects"""
27
27
28
- def __init__ (self , rasterio_ds ):
29
- self .rasterio_ds = rasterio_ds
30
- self ._shape = (rasterio_ds . count , rasterio_ds .height ,
31
- rasterio_ds .width )
28
+ def __init__ (self , riods ):
29
+ self .riods = riods
30
+ self ._shape = (riods . value . count , riods . value .height ,
31
+ riods . value .width )
32
32
self ._ndims = len (self .shape )
33
33
34
34
@property
35
35
def dtype (self ):
36
- dtypes = self .rasterio_ds .dtypes
36
+ dtypes = self .riods . value .dtypes
37
37
if not np .all (np .asarray (dtypes ) == dtypes [0 ]):
38
38
raise ValueError ('All bands should have the same dtype' )
39
39
return np .dtype (dtypes [0 ])
@@ -105,7 +105,7 @@ def _get_indexer(self, key):
105
105
def __getitem__ (self , key ):
106
106
band_key , window , squeeze_axis , np_inds = self ._get_indexer (key )
107
107
108
- out = self .rasterio_ds .read (band_key , window = tuple (window ))
108
+ out = self .riods . value .read (band_key , window = tuple (window ))
109
109
if squeeze_axis :
110
110
out = np .squeeze (out , axis = squeeze_axis )
111
111
return indexing .NumpyIndexingAdapter (out )[np_inds ]
@@ -194,28 +194,29 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
194
194
"""
195
195
196
196
import rasterio
197
- riods = rasterio .open (filename , mode = 'r' )
197
+
198
+ riods = PickleByReconstructionWrapper (rasterio .open , filename , mode = 'r' )
198
199
199
200
if cache is None :
200
201
cache = chunks is None
201
202
202
203
coords = OrderedDict ()
203
204
204
205
# Get bands
205
- if riods .count < 1 :
206
+ if riods .value . count < 1 :
206
207
raise ValueError ('Unknown dims' )
207
- coords ['band' ] = np .asarray (riods .indexes )
208
+ coords ['band' ] = np .asarray (riods .value . indexes )
208
209
209
210
# Get coordinates
210
211
if LooseVersion (rasterio .__version__ ) < '1.0' :
211
- transform = riods .affine
212
+ transform = riods .value . affine
212
213
else :
213
- transform = riods .transform
214
+ transform = riods .value . transform
214
215
if transform .is_rectilinear :
215
216
# 1d coordinates
216
217
parse = True if parse_coordinates is None else parse_coordinates
217
218
if parse :
218
- nx , ny = riods .width , riods .height
219
+ nx , ny = riods .value . width , riods . value .height
219
220
# xarray coordinates are pixel centered
220
221
x , _ = (np .arange (nx ) + 0.5 , np .zeros (nx ) + 0.5 ) * transform
221
222
_ , y = (np .zeros (ny ) + 0.5 , np .arange (ny ) + 0.5 ) * transform
@@ -238,41 +239,42 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
238
239
# For serialization store as tuple of 6 floats, the last row being
239
240
# always (0, 0, 1) per definition (see https://github.com/sgillies/affine)
240
241
attrs ['transform' ] = tuple (transform )[:6 ]
241
- if hasattr (riods , 'crs' ) and riods .crs :
242
+ if hasattr (riods . value , 'crs' ) and riods . value .crs :
242
243
# CRS is a dict-like object specific to rasterio
243
244
# If CRS is not None, we convert it back to a PROJ4 string using
244
245
# rasterio itself
245
- attrs ['crs' ] = riods .crs .to_string ()
246
- if hasattr (riods , 'res' ):
246
+ attrs ['crs' ] = riods .value . crs .to_string ()
247
+ if hasattr (riods . value , 'res' ):
247
248
# (width, height) tuple of pixels in units of CRS
248
- attrs ['res' ] = riods .res
249
- if hasattr (riods , 'is_tiled' ):
249
+ attrs ['res' ] = riods .value . res
250
+ if hasattr (riods . value , 'is_tiled' ):
250
251
# Is the TIF tiled? (bool)
251
252
# We cast it to an int for netCDF compatibility
252
- attrs ['is_tiled' ] = np .uint8 (riods .is_tiled )
253
+ attrs ['is_tiled' ] = np .uint8 (riods .value . is_tiled )
253
254
with warnings .catch_warnings ():
254
- # casting riods.transform to a tuple makes this future proof
255
+ # casting riods.value. transform to a tuple makes this future proof
255
256
warnings .simplefilter ('ignore' , FutureWarning )
256
- if hasattr (riods , 'transform' ):
257
+ if hasattr (riods . value , 'transform' ):
257
258
# Affine transformation matrix (tuple of floats)
258
259
# Describes coefficients mapping pixel coordinates to CRS
259
- attrs ['transform' ] = tuple (riods .transform )
260
- if hasattr (riods , 'nodatavals' ):
260
+ attrs ['transform' ] = tuple (riods .value . transform )
261
+ if hasattr (riods . value , 'nodatavals' ):
261
262
# The nodata values for the raster bands
262
263
attrs ['nodatavals' ] = tuple ([np .nan if nodataval is None else nodataval
263
- for nodataval in riods .nodatavals ])
264
+ for nodataval in riods .value . nodatavals ])
264
265
265
266
# Parse extra metadata from tags, if supported
266
267
parsers = {'ENVI' : _parse_envi }
267
268
268
- driver = riods .driver
269
+ driver = riods .value . driver
269
270
if driver in parsers :
270
- meta = parsers [driver ](riods .tags (ns = driver ))
271
+ meta = parsers [driver ](riods .value . tags (ns = driver ))
271
272
272
273
for k , v in meta .items ():
273
274
# Add values as coordinates if they match the band count,
274
275
# as attributes otherwise
275
- if isinstance (v , (list , np .ndarray )) and len (v ) == riods .count :
276
+ if (isinstance (v , (list , np .ndarray )) and
277
+ len (v ) == riods .value .count ):
276
278
coords [k ] = ('band' , np .asarray (v ))
277
279
else :
278
280
attrs [k ] = v
0 commit comments