diff --git a/xray/backends/api.py b/xray/backends/api.py index afc85a4d716..b4084bb71ff 100644 --- a/xray/backends/api.py +++ b/xray/backends/api.py @@ -39,7 +39,7 @@ def _get_default_engine(path, allow_remote=False): def open_dataset(filename_or_obj, group=None, decode_cf=True, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, engine=None, - chunks=None, lock=True): + chunks=None, lock=True, close_files=False): """Load and decode a dataset from a file or file-like object. Parameters @@ -84,6 +84,9 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, If chunks is provided, this argument is passed on to :py:func:`dask.array.from_array`. By default, a lock is used to avoid issues with concurrent access with dask's multithreaded backend. + close_files: bool, optional + If True, the engine should try to keep the underlying files closed. + Currently only supported for engine='scipy'. Returns ------- @@ -108,6 +111,9 @@ def maybe_decode_store(store): ds = ds.chunk(chunks, lock=lock) return ds + if close_files and engine != 'scipy': + raise ValueError('close_files only supported for scipy engine') + if isinstance(filename_or_obj, backends.AbstractDataStore): store = filename_or_obj elif isinstance(filename_or_obj, basestring): @@ -137,7 +143,7 @@ def maybe_decode_store(store): if engine == 'netcdf4': store = backends.NetCDF4DataStore(filename_or_obj, group=group) elif engine == 'scipy': - store = backends.ScipyDataStore(filename_or_obj) + store = backends.ScipyDataStore(filename_or_obj, close_files=close_files) elif engine == 'pydap': store = backends.PydapDataStore(filename_or_obj) elif engine == 'h5netcdf': @@ -147,7 +153,10 @@ def maybe_decode_store(store): % engine) with close_on_error(store): - return maybe_decode_store(store) + ds = maybe_decode_store(store) + if close_files: + store.ds.close() + return ds else: if engine is not None and engine != 'scipy': raise ValueError('can only read file-like objects with ' diff --git a/xray/backends/scipy_.py b/xray/backends/scipy_.py index 872c55fe2ce..7eb801671a2 100644 --- a/xray/backends/scipy_.py +++ b/xray/backends/scipy_.py @@ -3,6 +3,8 @@ import numpy as np import warnings +import scipy.io + from .. import Variable from ..conventions import cf_encoder from ..core.pycompat import iteritems, basestring, unicode_type, OrderedDict @@ -28,6 +30,33 @@ def _decode_attrs(d): for (k, v) in iteritems(d)) +class ScipyNetcdfWrapper(scipy.io.netcdf_file): + # wrap scipy.io.netcdf_file object in a class that + # can open or close the underlying file as needed + def __init__(self, *args, **kwargs): + self._is_open = True + self._variables = dict() + self._init_args = args + self._init_kwargs = kwargs.copy() + self.close_files = kwargs.pop('close_files') + super(ScipyNetcdfWrapper, self).__init__(*args, **kwargs) + + def close(self): + super(ScipyNetcdfWrapper, self).close() + self._is_open = False + + @property + def variables(self): + # check to see if the underlying file is closed + if not self._is_open: + # if so, reinitialize it + self.__init__(*self._init_args, **self._init_kwargs) + return self._variables + + @variables.setter + def variables(self, value): + self._variables = value + class ScipyArrayWrapper(NumpyIndexingAdapter): def __init__(self, netcdf_file, variable_name): self.netcdf_file = netcdf_file @@ -39,7 +68,10 @@ def array(self): # because otherwise scipy complains about variables or files still # referencing mmapped arrays when we try to close datasets without # having read all data in the file. - return self.netcdf_file.variables[self.variable_name].data + a = self.netcdf_file.variables[self.variable_name].data + if self.netcdf_file.close_files: + self.netcdf_file.close() + return a @property def dtype(self): @@ -53,6 +85,8 @@ def __getitem__(self, key): # after closing associated files. copy = self.netcdf_file.use_mmap data = np.array(data, dtype=self.dtype, copy=copy) + if self.netcdf_file.close_files: + self.netcdf_file.close() return data @@ -65,9 +99,8 @@ class ScipyDataStore(AbstractWritableDataStore): It only supports the NetCDF3 file-format. """ def __init__(self, filename_or_obj, mode='r', format=None, group=None, - writer=None, mmap=None): + writer=None, mmap=None, close_files=False): import scipy - import scipy.io if mode != 'r' and scipy.__version__ < '0.13': # pragma: no cover warnings.warn('scipy %s detected; ' 'the minimal recommended version is 0.13. ' @@ -87,16 +120,25 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) + # can't use mmap with close_files + if close_files: + mmap = False + # if filename is a NetCDF3 bytestring we store it in a StringIO if (isinstance(filename_or_obj, basestring) and filename_or_obj.startswith('CDF')): # TODO: this check has the unfortunate side-effect that # paths to files cannot start with 'CDF'. filename_or_obj = BytesIO(filename_or_obj) - self.ds = scipy.io.netcdf_file( - filename_or_obj, mode=mode, mmap=mmap, version=version) + self.ds = ScipyNetcdfWrapper( + filename_or_obj, mode=mode, mmap=mmap, version=version, + close_files=close_files) super(ScipyDataStore, self).__init__(writer) + if close_files: + self.ds.close() + + def store(self, variables, attributes): # All Scipy objects get CF encoded by default, without this attempting # to write times, for example, would fail.