Skip to content

Commit

Permalink
simplified args for Sup3rDataset. Single check for tuples in `Conta…
Browse files Browse the repository at this point in the history
…iner` data setter. sphinx_book_theme.
  • Loading branch information
bnb32 committed Aug 8, 2024
1 parent 47d1fe7 commit d8e0fe2
Show file tree
Hide file tree
Showing 21 changed files with 232 additions and 254 deletions.
32 changes: 17 additions & 15 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import re
import sys

import sphinx_autosummary_accessors

sys.path.insert(0, os.path.abspath('../../'))

# -- Project information -----------------------------------------------------
Expand All @@ -30,12 +31,12 @@
pkg = os.path.dirname(pkg)
sys.path.append(pkg)

from sup3r._version import __version__ as v
import sup3r

# The short X.Y version
version = re.search(r"^(\d+\.\d+)\.\d+(.dev\d+)?", v).group(0)
version = sup3r.__version__.split('+')[0]
# The full version, including alpha/beta/rc tags
release = re.search(r"^(\d+\.\d+\.\d+(.dev\d+)?)", v).group(0)
release = sup3r.__version__.split('+')[0]

# -- General configuration ---------------------------------------------------

Expand All @@ -47,28 +48,29 @@
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.viewcode',
'sphinx.ext.githubpages',
'sphinx.ext.napoleon',
'sphinx_rtd_theme',
'sphinx_click.ext',
'sphinx_tabs.tabs',
'sphinx_copybutton',
"sphinx_rtd_dark_mode"

"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.extlinks",
"sphinx.ext.mathjax",
"sphinx.ext.napoleon",
"sphinx_autosummary_accessors",
"sphinx_copybutton",
]

intersphinx_mapping = {
'python': ('https://docs.python.org/3/', None),
}

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates", sphinx_autosummary_accessors.templates_path]

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
Expand Down Expand Up @@ -106,14 +108,14 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
html_theme = 'sphinx_book_theme'

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
html_theme_options = {'navigation_depth': 4, 'collapse_navigation': False}
# html_css_file = ['custom.css']
# html_css_files = ['custom.css']

# user starts in light mode
default_dark_mode = False
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,7 @@ test = "pytest --pdb --durations=10 tests"

[tool.pixi.feature.doc.dependencies]
sphinx = ">=7.0"
sphinx_rtd_theme = ">=2.0"
sphinx-rtd-dark-mode = ">=1.3.0"
sphinx_book_theme = ">=1.1.3"

[tool.pixi.feature.test.dependencies]
pytest = ">=5.2"
Expand Down
2 changes: 1 addition & 1 deletion sup3r/bias/bias_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _run_single(
base_gid,
base_handler,
daily_reduction,
bias_ti,
bias_ti, # noqa: ARG003
decimals,
base_dh_inst=None,
match_zero_rate=False,
Expand Down
4 changes: 2 additions & 2 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from sup3r.preprocessing.data_handlers import ExoData
from sup3r.preprocessing.utilities import numpy_if_tensor
from sup3r.utilities import VERSION_RECORD
from sup3r.utilities.utilities import Timer
from sup3r.utilities.utilities import Timer, safe_cast

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1180,7 +1180,7 @@ def finish_epoch(self,

if extras is not None:
for k, v in extras.items():
self._history.at[epoch, k] = v
self._history.at[epoch, k] = safe_cast(v)

return stop

Expand Down
18 changes: 18 additions & 0 deletions sup3r/models/utilities.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
"""Utilities shared across the `sup3r.models` module"""

import logging
import sys

import numpy as np
from scipy.interpolate import RegularGridInterpolator

logger = logging.getLogger(__name__)


def TrainingSession(model):
"""Wrapper to gracefully exit batch handler thread during training, upon a
keyboard interruption."""

def wrapper(batch_handler, **kwargs):
"""Wrap model.train()."""
try:
logger.info('Starting training session.')
model.train(batch_handler, **kwargs)
except KeyboardInterrupt:
logger.info('Ending training session.')
batch_handler.stop()
sys.exit()

return wrapper


def st_interp(low, s_enhance, t_enhance, t_centered=False):
"""Spatiotemporal bilinear interpolation for low resolution field on a
regular grid. Used to provide baseline for comparison with gan output
Expand Down
25 changes: 15 additions & 10 deletions sup3r/postprocessing/collectors/nc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""NETCDF file collection."""

import logging
import os
import time
Expand All @@ -7,6 +8,8 @@
from gaps import Status
from rex.utilities.loggers import init_logger

from sup3r.preprocessing.utilities import _lowered

from .base import BaseCollector

logger = logging.getLogger(__name__)
Expand All @@ -20,13 +23,13 @@ def collect(
cls,
file_paths,
out_file,
features,
features='all',
log_level=None,
log_file=None,
write_status=False,
job_name=None,
overwrite=True,
res_kwargs=None
res_kwargs=None,
):
"""Collect data files from a dir to one output file.
Expand All @@ -40,8 +43,9 @@ def collect(
or a single string with unix-style /search/patt*ern.nc.
out_file : str
File path of final output file.
features : list
List of dsets to collect
features : list | str
List of dsets to collect. If 'all' then all ``data_vars`` will be
collected.
log_level : str | None
Desired log level, None will not initialize logging.
log_file : str | None
Expand All @@ -57,9 +61,7 @@ def collect(
"""
t0 = time.time()

logger.info(
f'Initializing collection for file_paths={file_paths}'
)
logger.info(f'Initializing collection for file_paths={file_paths}')

if log_level is not None:
init_logger(
Expand All @@ -80,10 +82,13 @@ def collect(
if not os.path.exists(out_file):
res_kwargs = res_kwargs or {}
out = xr.open_mfdataset(collector.flist, **res_kwargs)
features = [feat for feat in out if feat in features
or feat.lower() in features]
features = list(out.data_vars) if features == 'all' else features
features = set(features).intersection(_lowered(out.data_vars))
for feat in features:
out[feat].to_netcdf(out_file, mode='a')
mode = 'a' if os.path.exists(out_file) else 'w'
out[feat].load().to_netcdf(
out_file, mode=mode, engine='h5netcdf', format='NETCDF4'
)
logger.info(f'Finished writing {feat} to {out_file}.')

if write_status and job_name is not None:
Expand Down
2 changes: 1 addition & 1 deletion sup3r/postprocessing/writers/nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,5 @@ def _write_output(
meta_data=meta_data,
max_workers=max_workers,
gids=gids,
).to_netcdf(out_file)
).load().to_netcdf(out_file)
logger.info(f'Saved output of size {data.shape} to: {out_file}')
7 changes: 4 additions & 3 deletions sup3r/preprocessing/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,10 @@ def ordered(self, data):
return data.transpose(*ordered_dims(data.dims), ...)

def sample(self, idx):
"""Get sample from self._ds. The idx should be a tuple of slices for
the dimensions (south_north, west_east, time) and a list of feature
names."""
"""Get sample from ``self._ds``. The idx should be a tuple of slices
for the dimensions ``(south_north, west_east, time)`` and a list of
feature names. e.g.
``(slice(0, 3), slice(1, 10), slice(None), ['u_10m', 'v_10m'])``"""
isel_kwargs = dict(zip(Dimension.dims_3d(), idx[:-1]))
features = (
self.features if not is_strings(idx[-1]) else _lowered(idx[-1])
Expand Down
Loading

0 comments on commit d8e0fe2

Please sign in to comment.