diff --git a/docs/source/conf.py b/docs/source/conf.py index 3d15a55b2..cdb224e1a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,9 +15,10 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. # import os -import re import sys +import sphinx_autosummary_accessors + sys.path.insert(0, os.path.abspath('../../')) # -- Project information ----------------------------------------------------- @@ -30,12 +31,12 @@ pkg = os.path.dirname(pkg) sys.path.append(pkg) -from sup3r._version import __version__ as v +import sup3r # The short X.Y version -version = re.search(r"^(\d+\.\d+)\.\d+(.dev\d+)?", v).group(0) +version = sup3r.__version__.split('+')[0] # The full version, including alpha/beta/rc tags -release = re.search(r"^(\d+\.\d+\.\d+(.dev\d+)?)", v).group(0) +release = sup3r.__version__.split('+')[0] # -- General configuration --------------------------------------------------- @@ -47,20 +48,21 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', 'sphinx.ext.viewcode', 'sphinx.ext.githubpages', - 'sphinx.ext.napoleon', - 'sphinx_rtd_theme', 'sphinx_click.ext', 'sphinx_tabs.tabs', - 'sphinx_copybutton', - "sphinx_rtd_dark_mode" + + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.extlinks", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx_autosummary_accessors", + "sphinx_copybutton", ] intersphinx_mapping = { @@ -68,7 +70,7 @@ } # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates", sphinx_autosummary_accessors.templates_path] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: @@ -106,14 +108,14 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = 'sphinx_book_theme' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. 
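The version handling above splits on `'+'` to strip any local-build suffix from a setuptools_scm-style version string, so `version` and `release` both become the public version. A minimal sketch (the version string below is made up):

```python
# Hypothetical setuptools_scm-style version with a "+<local build>" suffix
v = "0.2.1.dev45+g1a2b3c4.d20240808"
version = v.split('+')[0]   # -> "0.2.1.dev45"
release = v.split('+')[0]   # short and full versions are now identical
print(version, release)
```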
# html_theme_options = {'navigation_depth': 4, 'collapse_navigation': False} -# html_css_file = ['custom.css'] +# html_css_files = ['custom.css'] # user starts in light mode default_dark_mode = False diff --git a/pyproject.toml b/pyproject.toml index 4a8de98ad..902c519f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -292,8 +292,7 @@ test = "pytest --pdb --durations=10 tests" [tool.pixi.feature.doc.dependencies] sphinx = ">=7.0" -sphinx_rtd_theme = ">=2.0" -sphinx-rtd-dark-mode = ">=1.3.0" +sphinx_book_theme = ">=1.1.3" [tool.pixi.feature.test.dependencies] pytest = ">=5.2" diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index 33b233a90..3697d9800 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -106,7 +106,7 @@ def _run_single( base_gid, base_handler, daily_reduction, - bias_ti, + bias_ti, # noqa: ARG003 decimals, base_dh_inst=None, match_zero_rate=False, diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index ac62ffd3f..2a76d5445 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -22,7 +22,7 @@ from sup3r.preprocessing.data_handlers import ExoData from sup3r.preprocessing.utilities import numpy_if_tensor from sup3r.utilities import VERSION_RECORD -from sup3r.utilities.utilities import Timer +from sup3r.utilities.utilities import Timer, safe_cast logger = logging.getLogger(__name__) @@ -1180,7 +1180,7 @@ def finish_epoch(self, if extras is not None: for k, v in extras.items(): - self._history.at[epoch, k] = v + self._history.at[epoch, k] = safe_cast(v) return stop diff --git a/sup3r/models/utilities.py b/sup3r/models/utilities.py index 4c5378fb2..1eca98e05 100644 --- a/sup3r/models/utilities.py +++ b/sup3r/models/utilities.py @@ -1,6 +1,7 @@ """Utilities shared across the `sup3r.models` module""" import logging +import sys import numpy as np from scipy.interpolate import RegularGridInterpolator @@ -8,6 +9,23 @@ logger = logging.getLogger(__name__) +def TrainingSession(model): + """Wrapper to gracefully exit batch handler thread during training, upon a + keyboard interruption.""" + + def wrapper(batch_handler, **kwargs): + """Wrap model.train().""" + try: + logger.info('Starting training session.') + model.train(batch_handler, **kwargs) + except KeyboardInterrupt: + logger.info('Ending training session.') + batch_handler.stop() + sys.exit() + + return wrapper + + def st_interp(low, s_enhance, t_enhance, t_centered=False): """Spatiotemporal bilinear interpolation for low resolution field on a regular grid. Used to provide baseline for comparison with gan output diff --git a/sup3r/postprocessing/collectors/nc.py b/sup3r/postprocessing/collectors/nc.py index f0be9d018..8dd01eb9a 100644 --- a/sup3r/postprocessing/collectors/nc.py +++ b/sup3r/postprocessing/collectors/nc.py @@ -1,4 +1,5 @@ """NETCDF file collection.""" + import logging import os import time @@ -7,6 +8,8 @@ from gaps import Status from rex.utilities.loggers import init_logger +from sup3r.preprocessing.utilities import _lowered + from .base import BaseCollector logger = logging.getLogger(__name__) @@ -20,13 +23,13 @@ def collect( cls, file_paths, out_file, - features, + features='all', log_level=None, log_file=None, write_status=False, job_name=None, overwrite=True, - res_kwargs=None + res_kwargs=None, ): """Collect data files from a dir to one output file. @@ -40,8 +43,9 @@ def collect( or a single string with unix-style /search/patt*ern.nc. out_file : str File path of final output file. 
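The new `TrainingSession` helper in `sup3r/models/utilities.py` wraps `model.train()` so that a keyboard interrupt stops the batch-handler queue thread and exits cleanly, complementing the removal of the `KeyboardInterrupt` handling from `enqueue_batches` later in this diff. A runnable sketch with stand-in objects (the real arguments would be a sup3r model and batch handler; the dummy classes below are illustrative only):

```python
from sup3r.models.utilities import TrainingSession  # added by this diff

class DummyHandler:
    def stop(self):
        print('batch queue stopped')

class DummyModel:
    def train(self, batch_handler, **kwargs):
        print('training with', kwargs)

train = TrainingSession(DummyModel())  # returns the wrapped train function
train(DummyHandler(), n_epoch=2)       # Ctrl-C here would call .stop() and exit
```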
- features : list - List of dsets to collect + features : list | str + List of dsets to collect. If 'all' then all ``data_vars`` will be + collected. log_level : str | None Desired log level, None will not initialize logging. log_file : str | None @@ -57,9 +61,7 @@ def collect( """ t0 = time.time() - logger.info( - f'Initializing collection for file_paths={file_paths}' - ) + logger.info(f'Initializing collection for file_paths={file_paths}') if log_level is not None: init_logger( @@ -80,10 +82,13 @@ def collect( if not os.path.exists(out_file): res_kwargs = res_kwargs or {} out = xr.open_mfdataset(collector.flist, **res_kwargs) - features = [feat for feat in out if feat in features - or feat.lower() in features] + features = list(out.data_vars) if features == 'all' else features + features = set(features).intersection(_lowered(out.data_vars)) for feat in features: - out[feat].to_netcdf(out_file, mode='a') + mode = 'a' if os.path.exists(out_file) else 'w' + out[feat].load().to_netcdf( + out_file, mode=mode, engine='h5netcdf', format='NETCDF4' + ) logger.info(f'Finished writing {feat} to {out_file}.') if write_status and job_name is not None: diff --git a/sup3r/postprocessing/writers/nc.py b/sup3r/postprocessing/writers/nc.py index 643e6b048..6332c12ac 100644 --- a/sup3r/postprocessing/writers/nc.py +++ b/sup3r/postprocessing/writers/nc.py @@ -126,5 +126,5 @@ def _write_output( meta_data=meta_data, max_workers=max_workers, gids=gids, - ).to_netcdf(out_file) + ).load().to_netcdf(out_file) logger.info(f'Saved output of size {data.shape} to: {out_file}') diff --git a/sup3r/preprocessing/accessor.py b/sup3r/preprocessing/accessor.py index 74c9379c0..444d3527d 100644 --- a/sup3r/preprocessing/accessor.py +++ b/sup3r/preprocessing/accessor.py @@ -308,9 +308,10 @@ def ordered(self, data): return data.transpose(*ordered_dims(data.dims), ...) def sample(self, idx): - """Get sample from self._ds. The idx should be a tuple of slices for - the dimensions (south_north, west_east, time) and a list of feature - names.""" + """Get sample from ``self._ds``. The idx should be a tuple of slices + for the dimensions ``(south_north, west_east, time)`` and a list of + feature names. e.g. + ``(slice(0, 3), slice(1, 10), slice(None), ['u_10m', 'v_10m'])``""" isel_kwargs = dict(zip(Dimension.dims_3d(), idx[:-1])) features = ( self.features if not is_strings(idx[-1]) else _lowered(idx[-1]) diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index b29b4092a..b81266fad 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -2,13 +2,17 @@ object, which just contains dataset objects. All objects that interact with data are containers. e.g. loaders, rasterizers, data handlers, samplers, batch queues, batch handlers. + +TODO: https://github.com/xarray-contrib/datatree might be a better approach +for Sup3rDataset concept. Consider migrating once datatree has been fully +integrated into xarray (in progress as of 8/8/2024) """ import logging import pprint from abc import ABCMeta from collections import namedtuple -from typing import Optional, Tuple, Union +from typing import Mapping, Tuple, Union from warnings import warn import numpy as np @@ -65,9 +69,9 @@ def __repr__(cls): class Sup3rDataset: """Interface for interacting with one or two ``xr.Dataset`` instances. - This is either a simple passthrough for a ``xr.Dataset`` instance or a - wrapper around two of them so they work well with Dual objects like - ``DualSampler``, ``DualRasterizer``, ``DualBatchHandler``, etc...) 
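The collector above now writes one variable at a time, creating the output file on the first write and appending afterwards, with each variable loaded into memory before writing. A self-contained sketch of that write pattern on synthetic data (assumes the h5netcdf backend is installed; the file name is arbitrary):

```python
import os

import numpy as np
import xarray as xr

ds = xr.Dataset({
    'u_10m': (('time',), np.arange(3, dtype=np.float32)),
    'v_10m': (('time',), np.arange(3, dtype=np.float32)),
})
out_file = 'collected_example.nc'
for feat in ds.data_vars:
    # 'w' creates the file for the first variable, 'a' appends the rest
    mode = 'a' if os.path.exists(out_file) else 'w'
    ds[feat].load().to_netcdf(
        out_file, mode=mode, engine='h5netcdf', format='NETCDF4'
    )
```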
+ This is a wrapper around one or two ``Sup3rX`` objects so they work well + with Dual objects like ``DualSampler``, ``DualRasterizer``, + ``DualBatchHandler``, etc...) Examples -------- @@ -103,95 +107,35 @@ class Sup3rDataset: def __init__( self, - data: Optional[ - Union[Tuple[xr.Dataset, ...], Tuple[Sup3rX, ...]] - ] = None, - **dsets: Union[xr.Dataset, Sup3rX], + **dsets: Mapping[str, Union[xr.Dataset, Sup3rX]], ): """ Parameters ---------- - data : Tuple[xr.Dataset | Sup3rX | Sup3rDataset] - ``Sup3rDataset`` will accomodate various types of data inputs, - which will ultimately be wrapped as a namedtuple of - :class:`~sup3r.preprocessing.Sup3rX` objects, stored in the - ``self._ds`` attribute. The preferred way to pass data here is - through dsets, which is a flexible **kwargs input. e.g. You can - provide ``name=data`` or ``name1=data1, name2=data2`` and these - names will be stored as attributes which point to that data. If - data is given as a tuple of :class:`~sup3r.preprocessing.Sup3rX` - objects then great, no prep needed. If given as a tuple of - ``xr.Dataset`` objects then each will be cast to ``Sup3rX`` - objects. If given as tuple of ``Sup3rDataset`` objects then we - make sure they contain only a single data member and use those to - initialize a new ``Sup3rDataset``. - - If the tuple here is a 1-tuple the namedtuple will use the name - "high_res" for the single dataset. If the tuple is a 2-tuple then - the first tuple member will be called "low_res" and the second - will be called "high_res". - - dsets : **dict[str, Union[xr.Dataset, Sup3rX]] - The preferred way to initialize a ``Sup3rDataset`` object, as a - dictionary with keys used to name a namedtuple of ``Sup3rX`` - objects. If dsets contains xr.Dataset objects these will be cast - to ``Sup3rX`` objects first. - + dsets : Mapping[str, xr.Dataset | Sup3rX | Sup3rDataset] + ``Sup3rDataset`` is initialized from a flexible kwargs input. The + keys will be used as names in a named tuple and the values will be + the dataset members. These names will also be used to define + attributes which point to these dataset members. You can provide + ``name=data`` or ``name1=data1, name2=data2`` and then access these + datasets as ``.name1`` or ``.name2``. If dsets values are + xr.Dataset objects these will be cast to ``Sup3rX`` objects first. + We also check if dsets values are ``Sup3rDataset`` objects and if + they only include one data member we use those to reinitialize a + ``Sup3rDataset`` """ - if data is not None: - data = data if isinstance(data, tuple) else (data,) - if all(isinstance(d, type(self)) for d in data): - msg = ( - 'Sup3rDataset received a tuple of Sup3rDataset objects' - ', each with two data members. If you insist on ' - 'initializing a Sup3rDataset with a tuple of the same, ' - 'then they have to be singletons.' - ) - assert all(len(d) == 1 for d in data), msg - msg = ( - 'Sup3rDataset received a tuple of Sup3rDataset ' - 'objects. You got away with it this time because they ' - 'each contain a single data member, but be careful' - ) - logger.warning(msg) - warn(msg) - if len(data) == 1: - msg = ( - f'{self.__class__.__name__} received a single data member ' - 'without an explicit name. Interpreting this as ' - '(high_res,). To be explicit provide keyword arguments ' - 'like Sup3rDataset(high_res=data[0])' - ) - logger.warning(msg) - warn(msg) - dsets = {'high_res': data[0]} - elif len(data) == 2: - msg = ( - f'{self.__class__.__name__} received a data tuple. ' - 'Interpreting this as (low_res, high_res). 
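As described above, member names come straight from the keyword arguments. A short sketch using the fake-data helper from the test suite (the import paths are assumptions based on this repository's tests):

```python
from sup3r.preprocessing import Sup3rDataset, Sup3rX
from sup3r.utilities.pytest.helpers import make_fake_dset

lr = Sup3rX(make_fake_dset((10, 10, 50, 2), features=['u_10m', 'v_10m']))
hr = Sup3rX(make_fake_dset((20, 20, 100, 2), features=['u_10m', 'v_10m']))

ds = Sup3rDataset(low_res=lr, high_res=hr)
_ = ds.low_res    # keyword names become attributes pointing at the members
_ = ds.high_res
_ = ds['u_10m']   # feature access is dispatched to the wrapped members
```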
To be explicit ' - 'provide keyword arguments like ' - 'Sup3rDataset(low_res=data[0], high_res=data[1])' - ) - logger.warning(msg) - warn(msg) - dsets = {'low_res': data[0], 'high_res': data[1]} - else: + for name, dset in dsets.items(): + if isinstance(dset, xr.Dataset): + dsets[name] = Sup3rX(dset) + elif isinstance(dset, type(self)): msg = ( - f'{self.__class__.__name__} received tuple of length ' - f'{len(data)}. Can only handle 1 / 2 - tuples.' + 'Initializing Sup3rDataset with Sup3rDataset objects ' + 'which contain more than one member is not allowed.' ) - logger.error(msg) - raise ValueError(msg) - - dsets = { - k: Sup3rX(v) - if isinstance(v, xr.Dataset) - else v._ds[0] - if isinstance(v, type(self)) - else v - for k, v in dsets.items() - } + assert len(dset) == 1, msg + dsets[name] = dset._ds[0] + self._ds = namedtuple('Dataset', list(dsets))(**dsets) def __iter__(self): @@ -218,11 +162,7 @@ def __getattr__(self, attr): def _getattr(self, dset, attr): """Get attribute from single data member.""" - return ( - getattr(dset.sx, attr) - if hasattr(dset.sx, attr) - else getattr(dset, attr) - ) + return getattr(dset.sx, attr, getattr(dset, attr)) def _getitem(self, dset, item): """Get item from single data member.""" @@ -405,8 +345,7 @@ def data(self, data): :py:meth:`.wrap`""" self._data = self.wrap(data) - @staticmethod - def wrap(data): + def wrap(self, data): """ Return a :class:`~.Sup3rDataset` object or tuple of such. This is a tuple when the `.data` attribute belongs to a @@ -417,19 +356,27 @@ def wrap(data): 2-tuple when ``.data`` belongs to a dual container object like :class:`~.samplers.DualSampler` and a 1-tuple otherwise. """ - if isinstance(data, Sup3rDataset): + if data is None: return data - if isinstance(data, tuple) and all( - isinstance(d, Sup3rDataset) for d in data - ): + + check_sup3rds = all(isinstance(d, Sup3rDataset) for d in data) + check_sup3rds = check_sup3rds or isinstance(data, Sup3rDataset) + if check_sup3rds: return data - return ( - Sup3rDataset(low_res=data[0], high_res=data[1]) - if isinstance(data, tuple) and len(data) == 2 - else Sup3rDataset(high_res=data) - if data is not None and not isinstance(data, Sup3rDataset) - else data - ) + + if isinstance(data, tuple) and len(data) == 2: + msg = ( + f'{self.__class__.__name__}.data is being set with a ' + '2-tuple without explicit dataset names. We will assume ' + 'first tuple member is low-res and second is high-res.' + ) + logger.warning(msg) + warn(msg) + data = Sup3rDataset(low_res=data[0], high_res=data[1]) + elif not isinstance(data, Sup3rDataset): + name = getattr(data, 'name', None) or 'high_res' + data = Sup3rDataset(**{name: data}) + return data def post_init_log(self, args_dict=None): """Log additional arguments after initialization.""" diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index cdb63b867..31124de51 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -240,24 +240,19 @@ def enqueue_batches(self) -> None: """Callback function for queue thread. While training, the queue is checked for empty spots and filled. 
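The simplified `__init__` above boils down to a single namedtuple construction; a pure-Python illustration of that mechanism (stand-in string values instead of real datasets):

```python
from collections import namedtuple

dsets = {'low_res': 'LR data', 'high_res': 'HR data'}  # stand-in members
ds = namedtuple('Dataset', list(dsets))(**dsets)

print(ds.low_res, ds.high_res)  # dict keys become attribute names
print(len(ds))                  # the wrapper length is the member count
```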
In the training thread, batches are removed from the queue.""" - try: - while self.running: - needed = self.queue_cap - self.queue.size().numpy() - needed = min((self.max_workers, needed)) - if needed == 1 or self.enqueue_pool is None: - self._enqueue_batch() - elif needed > 0: - futures = [ - self.enqueue_pool.submit(self._enqueue_batch) - for _ in np.arange(needed) - ] - logger.debug('Added %s enqueue futures.', needed) - for future in as_completed(futures): - _ = future.result() - - except KeyboardInterrupt: - logger.info(f'Stopping {self._thread_name.title()} queue.') - self.stop() + while self.running: + needed = self.queue_cap - self.queue.size().numpy() + needed = min((self.max_workers, needed)) + if needed == 1 or self.enqueue_pool is None: + self._enqueue_batch() + elif needed > 0: + futures = [ + self.enqueue_pool.submit(self._enqueue_batch) + for _ in np.arange(needed) + ] + logger.debug('Added %s enqueue futures.', needed) + for future in as_completed(futures): + _ = future.result() def __next__(self) -> Batch: """Dequeue batch samples, squeeze if for a spatial only model, perform diff --git a/sup3r/preprocessing/cachers/base.py b/sup3r/preprocessing/cachers/base.py index 4e8fab9f5..3caa1fc26 100644 --- a/sup3r/preprocessing/cachers/base.py +++ b/sup3r/preprocessing/cachers/base.py @@ -164,8 +164,8 @@ def write_h5( feature : str Name of feature to write to file. data : xr.DataArray - Data to write to file. Comes from self.data[feature], so an xarray - DataArray with dims and attributes + Data to write to file. Comes from ``self.data[feature]``, so an + xarray DataArray with dims and attributes coords : dict Dictionary of coordinate variables chunks : dict | None @@ -187,16 +187,8 @@ def write_h5( times = coords[Dimension.TIME].astype(int) for k, v in attrs.items(): f.attrs[k] = v - data_dict = dict( - zip( - [ - 'time_index', - *Dimension.coords_2d(), - feature, - ], - [da.asarray(times), lats, lons, data], - ) - ) + keys = ['time_index', *Dimension.coords_2d(), feature] + data_dict = dict(zip(keys, [da.asarray(times), lats, lons, data])) for dset, vals in data_dict.items(): f_chunks = chunks.get(dset, None) if dset in Dimension.coords_2d(): @@ -209,7 +201,7 @@ def write_h5( ) da.store(vals, d) logger.debug( - f'Added {dset} to {out_file} with chunks={f_chunks}' + 'Added %s to %s with chunks=%s', dset, out_file, f_chunks ) @classmethod @@ -225,13 +217,14 @@ def write_netcdf( feature : str Name of feature to write to file. data : xr.DataArray - Data to write to file. Comes from self.data[feature], so an xarray - DataArray with dims and attributes + Data to write to file. Comes from ``self.data[feature]``, so an + xarray DataArray with dims and attributes coords : dict | xr.Dataset.coords - Dictionary of coordinate variables or xr.Dataset coords attribute. + Dictionary of coordinate variables or ``xr.Dataset`` coords + attribute. chunks : dict | None - Chunk sizes for coordinate dimensions. e.g. {'windspeed': - {'south_north': 100, 'west_east': 100, 'time': 10}} + Chunk sizes for coordinate dimensions. e.g. 
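`write_h5` above streams each dask array into a pre-created h5py dataset via `da.store`. A generic sketch of that pattern (names, sizes, and chunking are arbitrary examples):

```python
import dask.array as da
import h5py

arr = da.random.random((24, 100, 100), chunks=(8, 50, 50)).astype('float32')

with h5py.File('cache_example.h5', 'w') as f:
    d = f.create_dataset(
        'windspeed_100m', shape=arr.shape, dtype='float32', chunks=(8, 50, 50)
    )
    da.store(arr, d)  # computes and writes chunk by chunk
```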
``{'windspeed': + {'south_north': 100, 'west_east': 100, 'time': 10}}`` attrs : dict | None Optional attributes to write to file """ @@ -243,4 +236,5 @@ attrs=attrs, ) out = out.chunk(chunks.get(feature, 'auto')) - out.to_netcdf(out_file) + out.load().to_netcdf(out_file) + del out diff --git a/sup3r/preprocessing/data_handlers/exo.py b/sup3r/preprocessing/data_handlers/exo.py index b38f7ed53..1a5871e24 100644 --- a/sup3r/preprocessing/data_handlers/exo.py +++ b/sup3r/preprocessing/data_handlers/exo.py @@ -31,9 +31,9 @@ def __init__(self, feature, combine_type, model, data): ``data`` will be used as input to the forward pass for the model step given by ``model`` model : int - Specifies the model index which will use the `data`. For example, + Specifies the model index which will use the ``data``. For example, if ``model`` == 1 then the ``data`` will be used according to - `combine_type` in the 2nd model step in a MultiStepGan. + ``combine_type`` in the 2nd model step in a MultiStepGan. data : Union[np.ndarray, da.core.Array] The data to be used for the given model step. """ diff --git a/sup3r/preprocessing/data_handlers/factory.py b/sup3r/preprocessing/data_handlers/factory.py index bc360a928..19428cb08 100644 --- a/sup3r/preprocessing/data_handlers/factory.py +++ b/sup3r/preprocessing/data_handlers/factory.py @@ -68,13 +68,15 @@ def __init__( features will be loaded. Specify explicit feature names for derivations. res_kwargs : dict - kwargs for the `BaseLoader`. BaseLoader is usually - xr.open_mfdataset for NETCDF files and MultiFileResourceX for H5 - files. + Additional keyword arguments passed through to the ``BaseLoader``. + BaseLoader is usually xr.open_mfdataset for NETCDF files and + MultiFileResourceX for H5 files. chunks : dict | str - Dictionary of chunk sizes to use for call to - `dask.array.from_array()` or `xr.Dataset().chunk()`. Will be - converted to a tuple when used in `from_array().` + Dictionary of chunk sizes to pass through to + ``dask.array.from_array()`` or ``xr.Dataset().chunk()``. Will be + converted to a tuple when used in ``from_array()``. These are the + methods for H5 and NETCDF data, respectively. This argument can + be "auto" in addition to a dictionary. target : tuple (lat, lon) lower left corner of raster. Either need target+shape or raster_file. diff --git a/sup3r/preprocessing/loaders/base.py b/sup3r/preprocessing/loaders/base.py index 1e0c133b3..2c455863b 100644 --- a/sup3r/preprocessing/loaders/base.py +++ b/sup3r/preprocessing/loaders/base.py @@ -45,13 +45,15 @@ def __init__( Features to return in loaded dataset. If 'all' then all available features will be returned. res_kwargs : dict - kwargs for the `BaseLoader`. BaseLoader is usually - xr.open_mfdataset for NETCDF files and MultiFileResourceX for H5 - files. + Additional keyword arguments passed through to the ``BaseLoader``. + BaseLoader is usually xr.open_mfdataset for NETCDF files and + MultiFileResourceX for H5 files. chunks : dict | str - Dictionary of chunk sizes to use for call to - `dask.array.from_array()` or `xr.Dataset().chunk()`. Will be - converted to a tuple when used in `from_array().` + Dictionary of chunk sizes to pass through to + ``dask.array.from_array()`` or ``xr.Dataset().chunk()``. Will be + converted to a tuple when used in ``from_array()``. These are the + methods for H5 and NETCDF data, respectively. This argument can + be "auto" in addition to a dictionary. BaseLoader : Callable Optional base loader update.
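On the `chunks` argument documented above: the same sizes go to `xr.Dataset.chunk` as a dimension-name dict for NETCDF data, and to `dask.array.from_array` as a tuple for H5 data; both also accept `'auto'`. A small illustration (dimension names and sizes are examples only):

```python
import dask.array as da
import numpy as np
import xarray as xr

chunks = {'time': 10, 'south_north': 50, 'west_east': 50}

# NETCDF path: chunk an xarray Dataset by named dimensions
ds = xr.Dataset({
    'u_10m': (
        ('time', 'south_north', 'west_east'),
        np.zeros((20, 100, 100), dtype=np.float32),
    )
})
ds = ds.chunk(chunks)

# H5 path: the same sizes as a tuple for dask.array.from_array
arr = np.zeros((20, 100, 100), dtype=np.float32)
darr = da.from_array(arr, chunks=tuple(chunks[d] for d in chunks))
```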
The default for H5 files is MultiFileResourceX and for NETCDF is xarray.open_mfdataset diff --git a/sup3r/preprocessing/loaders/h5.py b/sup3r/preprocessing/loaders/h5.py index e62a9d162..f8e91c5c7 100644 --- a/sup3r/preprocessing/loaders/h5.py +++ b/sup3r/preprocessing/loaders/h5.py @@ -36,6 +36,12 @@ class LoaderH5(BaseLoader): def _time_independent(self): return 'time_index' not in self.res + @property + def _time_steps(self): + return ( + len(self.res['time_index']) if not self._time_independent else None + ) + def _meta_shape(self): """Get shape of spatial domain only.""" if 'latitude' in self.res.h5: @@ -57,7 +63,7 @@ def _res_shape(self): return ( self._meta_shape() if self._time_independent - else (len(self.res['time_index']), *self._meta_shape()) + else (self._time_steps, *self._meta_shape()) ) def _get_coords(self, dims): @@ -68,18 +74,10 @@ def _get_coords(self, dims): coord_base = ( self.res.h5 if 'latitude' in self.res.h5 else self.res.h5['meta'] ) - coords.update( - { - Dimension.LATITUDE: ( - dims[-len(self._meta_shape()) :], - da.from_array(coord_base['latitude']), - ), - Dimension.LONGITUDE: ( - dims[-len(self._meta_shape()) :], - da.from_array(coord_base['longitude']), - ), - } - ) + coord_dims = dims[-len(self._meta_shape()) :] + lats = (coord_dims, da.from_array(coord_base['latitude'])) + lons = (coord_dims, da.from_array(coord_base['longitude'])) + coords.update({Dimension.LATITUDE: lats, Dimension.LONGITUDE: lons}) return coords def _get_dset_tuple(self, dset, dims, chunks): @@ -114,10 +112,8 @@ def _get_dset_tuple(self, dset, dims, chunks): f'Received 1D feature "{dset}" with shape that does not ' 'the length of the meta nor the time_index.' ) - assert ( - not self._time_independent - and len(arr) == self.res['time_index'] - ), msg + is_ts = not self._time_independent and len(arr) == self._time_steps + assert is_ts, msg arr_dims = (Dimension.TIME,) else: arr_dims = dims[: len(arr.shape)] @@ -125,7 +121,7 @@ def _get_dset_tuple(self, dset, dims, chunks): def _get_data_vars(self, dims): """Define data_vars dict for xr.Dataset construction.""" - data_vars: Dict[str, Tuple] = {} + data_vars = {} logger.debug(f'Rechunking features with chunks: {self.chunks}') chunks = ( tuple(self.chunks[d] for d in dims) @@ -133,33 +129,26 @@ def _get_data_vars(self, dims): else self.chunks ) if len(self._meta_shape()) == 1 and 'elevation' in self.res.meta: - data_vars['elevation'] = da.asarray( - self.res.meta['elevation'].values.astype(np.float32) - ) + elev = self.res.meta['elevation'].values.astype(np.float32) + elev = da.asarray(elev) if not self._time_independent: - data_vars['elevation'] = da.repeat( - data_vars['elevation'][None, ...], - len(self.res['time_index']), - axis=0, - ) - data_vars['elevation'] = data_vars['elevation'].rechunk(chunks) - data_vars['elevation'] = (dims, data_vars['elevation']) - data_vars.update( - { - f: self._get_dset_tuple(dset=f, dims=dims, chunks=chunks) - for f in set(self.res.h5.datasets) - - {'meta', 'time_index', 'coordinates'} - } - ) + t_steps = len(self.res['time_index']) + elev = da.repeat(elev[None, ...], t_steps, axis=0) + elev = elev.rechunk(chunks) + data_vars['elevation'] = (dims, elev) + + feats = set(self.res.h5.datasets) + exclude = {'meta', 'time_index', 'coordinates'} + for f in feats - exclude: + data_vars[f] = self._get_dset_tuple( + dset=f, dims=dims, chunks=chunks + ) return data_vars def _get_dims(self): """Get tuple of named dims for dataset.""" if len(self._meta_shape()) == 2: - dims: Tuple[str, ...] 
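For context, the H5 loader builds coordinates from the resource meta and gives variables time and flattened-space dimensions. A generic xarray sketch of that layout (the dimension names here are illustrative, not the actual `Dimension` enum values):

```python
import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr

meta = pd.DataFrame({'latitude': [40.0, 40.1], 'longitude': [-105.0, -105.1]})
times = pd.date_range('2020-01-01', periods=4, freq='h')
ws = da.random.random((len(times), len(meta)), chunks=(2, 2)).astype(np.float32)

ds = xr.Dataset(
    coords={
        'latitude': (('space',), da.from_array(meta['latitude'].values)),
        'longitude': (('space',), da.from_array(meta['longitude'].values)),
        'time': times,
    },
    data_vars={'windspeed_100m': (('time', 'space'), ws)},
)
```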
= ( - Dimension.SOUTH_NORTH, - Dimension.WEST_EAST, - ) + dims = Dimension.dims_2d() else: dims = (Dimension.FLATTENED_SPATIAL,) if not self._time_independent: @@ -176,7 +165,6 @@ def load(self) -> xr.Dataset: for k, v in self._get_data_vars(dims).items() if k not in coords } - data_vars = {k: v for k, v in data_vars.items() if k not in coords} return xr.Dataset(coords=coords, data_vars=data_vars).astype( np.float32 ) @@ -184,7 +172,7 @@ def scale_factor(self, feature): """Get scale factor for given feature. Data is stored in scaled form to reduce memory.""" - feat = feature if feature in self.res else feature.lower() + feat = feature if feature in self.res.datasets else feature.lower() feat = self.res.h5[feat] return np.float32( 1.0 diff --git a/sup3r/preprocessing/loaders/nc.py b/sup3r/preprocessing/loaders/nc.py index f6ba8da55..733d4e9c5 100644 --- a/sup3r/preprocessing/loaders/nc.py +++ b/sup3r/preprocessing/loaders/nc.py @@ -3,6 +3,7 @@ classes.""" import logging +from warnings import warn import dask.array as da import numpy as np @@ -73,6 +74,8 @@ def get_coords(res): lats = res[Dimension.LATITUDE].data.squeeze().astype(np.float32) lons = res[Dimension.LONGITUDE].data.squeeze().astype(np.float32) + res.swap_dims({}) + if len(lats.shape) == 1: lons, lats = da.meshgrid(lons, lats) @@ -93,6 +96,26 @@ def get_coords(res): coords[Dimension.TIME] = times return coords + @staticmethod + def get_dims(res): + """Get dimension name map using our standard mapping and the names + used for coordinate dimensions.""" + rename_dims = {k: v for k, v in DIM_NAMES.items() if k in res.dims} + lat_dims = res[Dimension.LATITUDE].dims + lon_dims = res[Dimension.LONGITUDE].dims + if len(lat_dims) == 1 and len(lon_dims) == 1: + rename_dims[lat_dims[0]] = Dimension.SOUTH_NORTH + rename_dims[lon_dims[0]] = Dimension.WEST_EAST + else: + msg = ('2D Latitude and Longitude dimension names are different. ' + 'This is weird.') + if lon_dims != lat_dims: + logger.warning(msg) + warn(msg) + else: + rename_dims.update(dict(zip(lat_dims, Dimension.dims_2d()))) + return rename_dims + def load(self): """Load netcdf xarray.Dataset().""" res = lower_names(self.res) @@ -100,14 +123,13 @@ def load(self): k: v for k, v in COORD_NAMES.items() if k in res and v not in res } res = res.rename(rename_coords) - rename_dims = {k: v for k, v in DIM_NAMES.items() if k in res.dims} - res = res.swap_dims(rename_dims) if not all(coord in res for coord in Dimension.coords_2d()): err = 'Could not find valid coordinates in given files: %s' logger.error(err, self.file_paths) raise OSError(err % (self.file_paths)) + res = res.swap_dims(self.get_dims(res)) res = res.assign_coords(self.get_coords(res)) if isinstance(self.chunks, dict): res = res.chunk(self.chunks) diff --git a/sup3r/preprocessing/rasterizers/exo.py b/sup3r/preprocessing/rasterizers/exo.py index 59a21c5a9..4e085fabb 100644 --- a/sup3r/preprocessing/rasterizers/exo.py +++ b/sup3r/preprocessing/rasterizers/exo.py @@ -265,7 +265,7 @@ def data(self): if not os.path.exists(cache_fp): tmp_fp = cache_fp + f'{generate_random_string(10)}.tmp' - data.to_netcdf(tmp_fp) + data.load().to_netcdf(tmp_fp) shutil.move(tmp_fp, cache_fp) return data diff --git a/sup3r/preprocessing/rasterizers/extended.py b/sup3r/preprocessing/rasterizers/extended.py index eb938fc6a..cbe7aa88d 100644 --- a/sup3r/preprocessing/rasterizers/extended.py +++ b/sup3r/preprocessing/rasterizers/extended.py @@ -42,13 +42,15 @@ def __init__( Features to return in loaded dataset.
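The exo rasterizer above writes its cache through a temporary file and then moves it into place, so an interrupted run never leaves a partial file at the final cache path. A minimal sketch of that pattern (paths and data are placeholders):

```python
import shutil

import numpy as np
import xarray as xr

data = xr.Dataset({
    'topography': (
        ('south_north', 'west_east'), np.zeros((10, 10), dtype=np.float32)
    )
})
cache_fp = 'exo_cache.nc'              # hypothetical final cache path
tmp_fp = cache_fp + '.0123456789.tmp'  # hypothetical temp name
data.load().to_netcdf(tmp_fp)          # write fully to the temp file first
shutil.move(tmp_fp, cache_fp)          # then move it into place
```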
If 'all' then all available features will be returned. res_kwargs : dict - kwargs for the `BaseLoader`. BaseLoader is usually - xr.open_mfdataset for NETCDF files and MultiFileResourceX for H5 - files. + Additional keyword arguments passed through to the ``BaseLoader``. + BaseLoader is usually xr.open_mfdataset for NETCDF files and + MultiFileResourceX for H5 files. chunks : dict | str - Dictionary of chunk sizes to use for call to - `dask.array.from_array()` or xr.Dataset().chunk(). Will be - converted to a tuple when used in `from_array().` + Dictionary of chunk sizes to pass through to + ``dask.array.from_array()`` or ``xr.Dataset().chunk()``. Will be + converted to a tuple when used in ``from_array()``. These are the + methods for H5 and NETCDF data, respectively. This argument can + be "auto" in addition to a dictionary. target : tuple (lat, lon) lower left corner of raster. Either need target+shape or raster_file. diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 6bb2ec006..8a92e339b 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -347,7 +347,7 @@ def process_surface_file(self): ds = self.convert_z(ds, name='orog') ds = standardize_names(ds, ERA_NAME_MAP) ds = standardize_values(ds) - ds.to_netcdf(tmp_file) + ds.load().to_netcdf(tmp_file) os.replace(tmp_file, self.surface_file) logger.info( f'Finished processing {self.surface_file}. Moved {tmp_file} to ' @@ -406,7 +406,7 @@ def process_level_file(self): ds = standardize_names(ds, ERA_NAME_MAP) ds = standardize_values(ds) ds = self.add_pressure(ds) - ds.to_netcdf(tmp_file) + ds.load().to_netcdf(tmp_file) os.replace(tmp_file, self.level_file) logger.info( f'Finished processing {self.level_file}. Moved ' @@ -424,7 +424,7 @@ def _write_dsets(cls, files, out_file, kwargs=None): for f in set(ds.data_vars) - set(added_features): mode = 'w' if not os.path.exists(tmp_file) else 'a' logger.info('Adding %s to %s.', f, tmp_file) - ds.to_netcdf(tmp_file, mode=mode) + ds.load().to_netcdf(tmp_file, mode=mode) logger.info('Added %s to %s.', f, tmp_file) added_features.append(f) logger.info(f'Finished writing {tmp_file}') diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index a35fc78a0..6210fefe3 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -16,17 +16,20 @@ RANDOM_GENERATOR = np.random.default_rng(seed=42) +def safe_cast(o): + """Cast to type safe for serialization.""" + if isinstance(o, (float, np.float64, np.float32)): + return float(o) + if isinstance(o, (int, np.int64, np.int32)): + return int(o) + if isinstance(o, (tuple, np.ndarray)): + return list(o) + return str(o) + + def safe_serialize(obj): """json.dumps with non-serializable object handling.""" - def _default(o): - if isinstance(o, (np.float64, np.float32)): - return float(o) - if isinstance(o, (np.int64, np.int32)): - return int(o) - if isinstance(o, (tuple, np.ndarray)): - return list(o) - return str(o) - return json.dumps(obj, default=_default) + return json.dumps(obj, default=safe_cast) class Timer: diff --git a/tests/data_wrapper/test_access.py b/tests/data_wrapper/test_access.py index 4f43d09a1..79b63284c 100644 --- a/tests/data_wrapper/test_access.py +++ b/tests/data_wrapper/test_access.py @@ -93,10 +93,8 @@ def test_correct_single_member_access(data): def test_correct_multi_member_access(): """Make sure Data object works correctly.""" data = Sup3rDataset( - ( - Sup3rX(make_fake_dset((20, 20, 100, 3), features=['u', 'v'])), -
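The `safe_cast`/`safe_serialize` refactor matters because numpy scalars and arrays (common in training metrics) are not JSON serializable by default, and `safe_cast` is now also reused when writing history values in `finish_epoch`. A quick runnable check:

```python
import numpy as np

from sup3r.utilities.utilities import safe_cast, safe_serialize

metrics = {
    'loss_gen': np.float32(0.123),
    'epoch': np.int64(7),
    'weights': np.array([0.5, 0.5]),
}
print(safe_serialize(metrics))       # works, with safe_cast as the default hook
print(safe_cast(np.float32(0.123)))  # -> plain Python float
# json.dumps(metrics) would raise TypeError on the numpy types
```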
Sup3rX(make_fake_dset((20, 20, 100, 3), features=['u', 'v'])), - ) + first=Sup3rX(make_fake_dset((20, 20, 100, 3), features=['u', 'v'])), + second=Sup3rX(make_fake_dset((20, 20, 100, 3), features=['u', 'v'])) ) _ = data['u']
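A short follow-on sketch (not part of the diff) showing that the keyword names used in the updated test are also exposed as attributes on the wrapper, and that the wrapper length is the member count:

```python
_ = data.first     # members named by the kwargs above
_ = data.second
assert len(data) == 2
```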