Skip to content

Commit

Permalink
Merge pull request #304 from jmccreight/feat_subset_netcdf
Browse files Browse the repository at this point in the history
Feat subset netcdf
  • Loading branch information
jmccreight authored Jul 12, 2024
2 parents 7b909a2 + c2ac544 commit 4abf7e8
Show file tree
Hide file tree
Showing 6 changed files with 252 additions and 2 deletions.
103 changes: 103 additions & 0 deletions autotest/test_netcdf_subset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import pytest
import xarray as xr

import pywatershed as pws

# NetCDF files, given relative to each simulation directory, that the subset
# tests are parametrized over. Presumably these are a parameter file, a forcing
# file, and a model output file -- TODO confirm against the test domains.
file_types = (
    "parameters_dis_hru.nc",
    "tmax.nc",
    "output/sroff.nc",
)


@pytest.fixture(scope="function")
def nhm_ids(simulation):
    """Return the nhm_id values to keep when subsetting the current domain.

    The domain name is the part of ``simulation["name"]`` before the first
    ``":"``. Only the ``hru_1`` and ``drb_2yr`` domains are tested; any other
    domain is skipped. For ``hru_1`` the id at index 0 is returned, otherwise
    the ids at indices 0, 10 and 100 of ``parameters_dis_hru.nc``.
    """
    domain_name = simulation["name"].split(":")[0]
    if domain_name not in ["hru_1", "drb_2yr"]:
        pytest.skip("Only test_netcdf_subset hru_1 and drb_2yr")

    if domain_name == "hru_1":
        subset_inds = (0,)
    else:
        subset_inds = (0, 10, 100)

    param_file = simulation["dir"] / "parameters_dis_hru.nc"
    # Open with a context manager so the file handle is released. ``.values``
    # plus the fancy index yields an in-memory numpy array before the close.
    with xr.open_dataset(param_file) as param_ds:
        return param_ds.nhm_id.values[subset_inds,]


@pytest.mark.parametrize("file_type", file_types)
def test_subset_netcdf_file(simulation, file_type, nhm_ids, tmp_path):
    """Subset a NetCDF file on nhm_id and check the file that is written.

    The new file must contain exactly the requested nhm_ids, the same set of
    variables as the original, and the same dims on every data variable.
    """
    # do time in the test on the dataset not the file
    old_file = simulation["dir"] / file_type
    new_file = tmp_path / "new.nc"
    pws.utils.netcdf_utils.subset_netcdf_file(
        old_file,
        new_file,
        coord_dim_name="nhm_id",
        coord_dim_values_keep=nhm_ids,
    )

    # Context managers release both file handles (one of which lives in
    # tmp_path) as soon as the comparisons are done.
    with xr.open_dataset(old_file) as old_ds:
        with xr.open_dataset(new_file) as new_ds:
            assert set(new_ds.nhm_id.values) == set(nhm_ids)

            assert set(new_ds.variables) == set(old_ds.variables)

            for vv in new_ds:
                assert new_ds[vv].dims == old_ds[vv].dims


@pytest.mark.parametrize("file_type", file_types)
def test_subset_xr_ds(simulation, file_type, nhm_ids, tmp_path):
    """Subset an in-memory Dataset on nhm_id, and on time when it is present.

    The result must contain exactly the requested nhm_ids, the same set of
    variables as the input, the same dims on every data variable and, when a
    time window was requested, exactly the 26 time steps of that window.
    """
    start_time = end_time = None
    old_file = simulation["dir"] / file_type
    old_ds = xr.open_dataset(old_file)
    if "time" in old_ds.variables:
        start_time = old_ds.time[100]
        end_time = old_ds.time[125]
        # Create variables lacking the time and the nhm_id dimension,
        # respectively. The positional ``[:, 0]`` assumes the first data
        # variable is dimensioned (time, nhm_id) -- TODO confirm.
        data_var_names = list(old_ds.data_vars)
        old_ds["dum_var"] = old_ds[data_var_names[0]].isel(time=0)
        old_ds["some_var"] = old_ds[data_var_names[0]][:, 0].squeeze()

    new_ds = pws.utils.netcdf_utils.subset_xr(
        old_ds,
        start_time=start_time,
        end_time=end_time,
        coord_dim_name="nhm_id",
        coord_dim_values_keep=nhm_ids,
    )

    assert set(new_ds.nhm_id.values) == set(nhm_ids)

    assert set(new_ds.variables) == set(old_ds.variables)

    for vv in new_ds:
        assert new_ds[vv].dims == old_ds[vv].dims

    if start_time is not None:
        # Label-based slices include both end points, so time[100] through
        # time[125] is 26 steps. The original ``len(new_ds.time == 26)`` took
        # the length of a boolean array and passed for any non-empty axis.
        assert len(new_ds.time) == 26


@pytest.mark.parametrize("file_type", file_types[1:])
def test_subset_xr_da(simulation, file_type, nhm_ids, tmp_path):
    """Subset an in-memory DataArray on nhm_id and on a time window.

    The first entry of ``file_types`` is excluded from the parametrization;
    every file used here is indexed on ``time``. The result must contain
    exactly the requested nhm_ids, keep the input's dims, and hold exactly
    the 26 time steps of the requested window.
    """
    old_file = simulation["dir"] / file_type
    old_da = xr.open_dataarray(old_file)
    start_time = old_da.time[100]
    end_time = old_da.time[125]

    new_da = pws.utils.netcdf_utils.subset_xr(
        old_da,
        start_time=start_time,
        end_time=end_time,
        coord_dim_name="nhm_id",
        coord_dim_values_keep=nhm_ids,
    )

    assert set(new_da.nhm_id.values) == set(nhm_ids)

    assert new_da.dims == old_da.dims

    # Label-based slices include both end points, so time[100] through
    # time[125] is 26 steps. The original ``len(new_da.time == 26)`` took the
    # length of a boolean array and passed for any non-empty axis.
    assert len(new_da.time) == 26
2 changes: 2 additions & 0 deletions doc/api/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ Utils
ControlVariables
MmrToMf6Dfw
utils.cbh_file_to_netcdf
utils.netcdf_utils.subset_netcdf_file
utils.netcdf_utils.subset_xr
6 changes: 5 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ New Features
(:pull:`288`) By `James McCreight <https://github.com/jmccreight>`_.
- Control instances have a diff method to compare with other instances.
(:pull:`288`) By `James McCreight <https://github.com/jmccreight>`_.
- Feature to standardize subsetting input data (parameters and forcings) in
space and time either from file (:func:`utils.netcdf_utils.subset_netcdf_file`) or
in memory (:func:`utils.netcdf_utils.subset_xr`).
(:pull:`304`) By `James McCreight <https://github.com/jmccreight>`_.

Breaking Changes
~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -110,7 +114,7 @@ New features
.. _whats-new.1.0.0:

v1.0.0 (18 December 2023)
---------------------
-------------------------

New features
~~~~~~~~~~~~
Expand Down
3 changes: 3 additions & 0 deletions pywatershed/base/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,8 +574,11 @@ def subset_on_coord(
None
"""
# only doing it in place for now
# should we really roll our own?

# TODO: should almost work for 2+D? just linearizes np.where
# except that dims should be dropped with >1

if len(where) > 1:
raise NotImplementedError("at least not tested")

Expand Down
2 changes: 1 addition & 1 deletion pywatershed/hydrology/prms_channel_flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ def prms_channel_flow_graph_to_model_dict(
Note that if you want to run a :class:`FlowGraph` by itself, simply
forced by known inflows (and not in the context of other
:class:`Process`\ es in a :class:`Model`), then the helper function
:class:`base.Process`\ es in a :class:`Model`), then the helper function
:func:`prms_channel_flow_graph_postprocess` is for you.
Please see the example notebook `examples/06_flow_graph_starfit.ipynb <https://github.com/EC-USGS/pywatershed/blob/develop/examples/06_flow_graph_starfit.ipynb>`__
Expand Down
138 changes: 138 additions & 0 deletions pywatershed/utils/netcdf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import netCDF4 as nc4
import numpy as np
import xarray as xr

from ..base.accessor import Accessor
from ..base.meta import meta_dimensions, meta_netcdf_type
Expand Down Expand Up @@ -622,3 +623,140 @@ def add_all_data(
self.variables[name][:, :] = data[:, :]

return


def subset_netcdf_file(
    file_name: Union[pl.Path, str],
    new_file_name: Union[pl.Path, str],
    start_time: np.datetime64 = None,
    end_time: np.datetime64 = None,
    coord_dim_name: str = None,
    coord_dim_values_keep: np.ndarray = None,
) -> None:
    """Subset a netcdf file on coordinate or dimension values and/or time.

    The input file is loaded with :func:`xarray.load_dataset`, subset by
    :func:`pywatershed.utils.netcdf_utils.subset_xr`, and the result is
    written to ``new_file_name``.

    Args:
        file_name: The name/path of the input file.
        new_file_name: The name/path of the output file.
        start_time: Optional start time if a "time" coord is present.
        end_time: Optional end time if a "time" coord is present.
        coord_dim_name: Optional coord or dimension name to subset on.
        coord_dim_values_keep: Optional values on the coord or dimension to
            retain in the subset.

    This currently works for 1-D coordinates, more dimensions not tested.

    Note: To subset xr.Datasets that are already in memory, call
    :func:`pywatershed.utils.netcdf_utils.subset_xr` directly. There seem to
    be several edge cases lurking around here with zero length dimensions and
    xarray's broadcasting rules. This function is a convenience function
    because xarray's functionality is not ideal for our use cases and is
    confusing with pitfalls. See
    https://github.com/pydata/xarray/issues/8796
    for additional discussion.
    """
    full_ds = xr.load_dataset(file_name)
    subset_ds = subset_xr(
        ds=full_ds,
        start_time=start_time,
        end_time=end_time,
        coord_dim_name=coord_dim_name,
        coord_dim_values_keep=coord_dim_values_keep,
    )
    subset_ds.to_netcdf(new_file_name)


def subset_xr(
    ds: Union[xr.Dataset, xr.DataArray],
    start_time: np.datetime64 = None,
    end_time: np.datetime64 = None,
    coord_dim_name: str = None,
    coord_dim_values_keep: np.ndarray = None,
) -> Union[xr.Dataset, xr.DataArray]:
    """Subset an xarray Dataset or DataArray on to coord or dim values.

    Args:
        ds: The Dataset or DataArray to subset.
        start_time: Optional start time if a "time" coord is present.
        end_time: Optional end time if a "time" coord is present.
        coord_dim_name: Optional coord or dimension name to subset on.
        coord_dim_values_keep: Optional values on the coord or dimension to
            retain in the subset.

    Returns:
        The subset object, of the same kind (Dataset or DataArray) as ``ds``.

    Raises:
        AssertionError: If only one of ``coord_dim_name`` and
            ``coord_dim_values_keep`` is supplied, if only one of
            ``start_time`` and ``end_time`` is supplied, or if none of
            ``coord_dim_values_keep`` is found on ``coord_dim_name``.

    This currently works for 1-D coordinates, more dimensions not tested.
    The time window is only applied when "time" is a dimension of ``ds``.

    To work with files rather than memory see
    :func:`pywatershed.utils.netcdf_utils.subset_netcdf_file`.

    Note: There seem to be several edge cases lurking around here with zero
    length dimensions and xarray's broadcasting rules. This function is a
    convenience function because xarray's functionality is not ideal for our
    use cases and is confusing with pitfalls. See
    https://github.com/pydata/xarray/issues/8796 for additional discussion.
    """
    # Remember each variable's original dims so that any dims ``where``
    # broadcasts onto a variable can be detected and removed afterwards.
    if isinstance(ds, xr.DataArray):
        var_dims_orig = ds.dims
    else:
        var_dims_orig = {key: ds[key].dims for key in ds.variables}

    if coord_dim_name is not None or coord_dim_values_keep is not None:
        msg = (
            "Neither or both of coord_dim_name and coord_dim_values_keep "
            "must be supplied."
        )
        assert (
            coord_dim_name is not None and coord_dim_values_keep is not None
        ), msg

    if coord_dim_name is not None:
        # Compute the membership mask once; it serves both the sanity check
        # and the selection.
        keep_mask = ds[coord_dim_name].isin(coord_dim_values_keep)
        msg = f"{coord_dim_values_keep=} not in {coord_dim_name=}"
        assert keep_mask.any(), msg

        ds = ds.where(keep_mask, drop=True)

        if isinstance(ds, xr.DataArray):
            extra_dims = list(set(ds.dims) - set(var_dims_orig))
            for dd in extra_dims:
                ds = ds.isel({dd: 0}).squeeze()
        else:
            for var in list(ds.variables):
                extra_dims = list(
                    set(ds[var].dims) - set(var_dims_orig[var])
                )
                if not extra_dims:
                    continue

                # a headache to deal with when it broadcasts to a zero
                # or non-zero length dimension
                dim_dict = dict(zip(ds[var].dims, ds[var].shape))
                extra_dim_lens = np.array(
                    [dim_dict[dd] for dd in extra_dims]
                )
                if (extra_dim_lens <= 1).all():
                    ds[var] = ds[var].squeeze(extra_dims)
                else:
                    for dd in extra_dims:
                        # an earlier squeeze may already have removed it
                        if dd in ds[var].dims:
                            ds[var] = ds[var].isel({dd: 0}).squeeze()

    if start_time is not None or end_time is not None:
        msg = "Neither or both of start_time and end_time must be supplied."
        assert start_time is not None and end_time is not None, msg

    # does sel work correctly here?
    if start_time is not None and "time" in ds.dims:
        ds = ds.sel(time=slice(start_time, end_time))

    return ds

0 comments on commit 4abf7e8

Please sign in to comment.