Skip to content

Commit

Permalink
Merge branch 'main' into regrid2_performance
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonb5 committed Jan 19, 2024
2 parents 3312be3 + fbf1db6 commit 9377a88
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ markers = ["flaky", "network"]

[tool.mypy]
# Docs: https://mypy.readthedocs.io/en/stable/config_file.html
python_version = 3.10
python_version = "3.10"
check_untyped_defs = true
ignore_missing_imports = true
warn_unused_ignores = true
Expand Down
49 changes: 23 additions & 26 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ def departures(
# The original time dimension name is restored after grouped
# arithmetic, so the labeled time dimension name is no longer needed
# and therefore dropped.
ds_obs = ds_obs.drop_vars(self._labeled_time.name)
ds_obs = ds_obs.drop_vars(str(self._labeled_time.name))

if weighted and keep_weights:
self._weights = ds_climo.time_wts
Expand All @@ -757,21 +757,17 @@ def _averager(
# Preprocess the dataset based on method argument values.
ds = self._preprocess_dataset(ds)

# Get the data variable and the required time axis metadata.
dv = _get_data_var(ds, data_var)
time_bounds = ds.bounds.get_bounds("T", var_key=dv.name)

if self._mode == "average":
dv = self._average(dv, time_bounds)
dv_avg = self._average(ds, data_var)
elif self._mode in ["group_average", "climatology", "departures"]:
dv = self._group_average(dv, time_bounds)
dv_avg = self._group_average(ds, data_var)

# The original time dimension is dropped from the dataset because
# it becomes obsolete after the data variable is averaged. When the
# averaged data variable is added to the dataset, the new time dimension
# and its associated coordinates are also added.
ds = ds.drop_dims(self.dim) # type: ignore
ds[dv.name] = dv
ds[dv_avg.name] = dv_avg

if keep_weights:
ds = self._keep_weights(ds)
Expand Down Expand Up @@ -1075,28 +1071,28 @@ def _drop_leap_days(self, ds: xr.Dataset):
)
return ds

def _average(
self, data_var: xr.DataArray, time_bounds: xr.DataArray
) -> xr.DataArray:
def _average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
"""Averages a data variable with the time dimension removed.
Parameters
----------
data_var : xr.DataArray
The data variable.
time_bounds : xr.DataArray
The time bounds.
ds : xr.Dataset
The dataset.
data_var : str
The key of the data variable.
Returns
-------
xr.DataArray
The averages for a data variable with the time dimension removed.
The data variable averaged with the time dimension removed.
"""
dv = data_var.copy()
dv = _get_data_var(ds, data_var)

with xr.set_options(keep_attrs=True):
if self._weighted:
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)
self._weights = self._get_weights(time_bounds)

dv = dv.weighted(self._weights).mean(dim=self.dim) # type: ignore
else:
dv = dv.mean(dim=self.dim) # type: ignore
Expand All @@ -1105,31 +1101,31 @@ def _average(

return dv

def _group_average(
self, data_var: xr.DataArray, time_bounds: xr.DataArray
) -> xr.DataArray:
def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
"""Averages a data variable by time group.
Parameters
----------
data_var : xr.DataArray
The data variable.
time_bounds : xr.DataArray
The time bounds.
ds : xr.Dataset
The dataset.
data_var : str
The key of the data variable.
Returns
-------
xr.DataArray
The data variable averaged by time group.
"""
dv = data_var.copy()
dv = _get_data_var(ds, data_var)

# Label the time coordinates for grouping weights and the data variable
# values.
self._labeled_time = self._label_time_coords(dv[self.dim])

if self._weighted:
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)
self._weights = self._get_weights(time_bounds)

# Weight the data variable.
dv *= self._weights

Expand All @@ -1145,8 +1141,9 @@ def _group_average(
# included to take into account zero weight for missing data.
with xr.set_options(keep_attrs=True):
dv = self._group_data(dv).sum() / self._group_data(weights).sum()

# Restore the data variable's name.
dv.name = data_var.name
dv.name = data_var
else:
dv = self._group_data(dv).mean()

Expand Down

0 comments on commit 9377a88

Please sign in to comment.