Skip to content

Commit

Permalink
Allow for computing multiple versions of the same statistical operator
Browse files Browse the repository at this point in the history
Reworked the compute statistics config and operators to be able to compute different versions of the same statistical operator, e.g. applying the mean over different dimensions.
  • Loading branch information
mafdmi committed Nov 21, 2024
1 parent e257056 commit 36cb5d9
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 47 deletions.
40 changes: 26 additions & 14 deletions example.danra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,32 @@ output:
start: 1990-09-03T00:00
end: 1990-09-06T00:00
compute_statistics:
MeanStatistic:
dims: [grid_index, time]
StdStatistic:
dims: [grid_index, time]
DiffMeanStatistic:
dims: [grid_index, time]
DiffTimeMeanStatistic:
dims: [time]
DiffStdStatistic:
dims: [grid_index, time]
DiurnalDiffMeanStatistic:
dims: [grid_index, time]
DiurnalDiffStdStatistic:
dims: [grid_index, time]
MeanOperator:
- name: mean
dims: [grid_index, time]
StdOperator:
- name: std
dims: [grid_index, time]
DiffMeanOperator:
- name: diff_mean
dims: [grid_index, time]
- name: diff_time_mean
dims: [time]
DiffStdOperator:
- name: diff_std
dims: [grid_index, time]
- name: diff_time_std
dims: [time]
DiurnalDiffMeanOperator:
- name: diurnal_diff_mean
dims: [grid_index, time]
- name: diurnal_diff_time_mean
dims: [time]
DiurnalDiffStdOperator:
- name: diurnal_diff_std
dims: [grid_index, time]
- name: diurnal_diff_time_std
dims: [time]
val:
start: 1990-09-06T00:00
end: 1990-09-07T00:00
Expand Down
3 changes: 2 additions & 1 deletion mllam_data_prep/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ class Statistic:
The dimensions to compute the statistics over, e.g. ["time", "grid_index"].
"""

name: str
dims: List[str]


Expand All @@ -192,7 +193,7 @@ class Split:

start: str
end: str
compute_statistics: Dict[str, Statistic] = None
compute_statistics: Dict[str, List[Statistic]] = None


@dataclass
Expand Down
43 changes: 11 additions & 32 deletions mllam_data_prep/ops/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict
from typing import Dict, List

import xarray as xr

from ..config import Statistic


def calc_stats(
ds: xr.Dataset, statistic_configs: Dict[str, Statistic], splitting_dim: str
ds: xr.Dataset, statistic_configs: Dict[str, List[Statistic]], splitting_dim: str
) -> Dict[str, xr.Dataset]:
"""
Calculate statistics for a given DataArray by applying the operations
Expand All @@ -35,10 +35,15 @@ def calc_stats(
Dictionary with the operation names as keys and the calculated statistics as values
"""
stats = {}
for stat_name, statistic in statistic_configs.items():
for stat_name, statistics in statistic_configs.items():
if stat_name in globals():
stat: StatisticOperator = globals()[stat_name](ds, splitting_dim)
stats[stat.name] = stat.calc_stats(statistic.dims)
# Apply the operation to the dataset (multiple different configurations
# of the same operator can be applied)
for statistic in statistics:
operator: StatisticOperator = globals()[stat_name](
ds=ds, splitting_dim=splitting_dim, name=statistic.name
)
stats[statistic.name] = operator.calc_stats(statistic.dims)
else:
raise NotImplementedError(stat_name)

Expand All @@ -62,11 +67,7 @@ class StatisticOperator(ABC):

ds: xr.Dataset
splitting_dim: str

@property
@abstractmethod
def name(self):
"""Override property to specify the name of the statistic"""
name: str

@abstractmethod
def calc_stats(self, dims):
Expand All @@ -76,26 +77,20 @@ def calc_stats(self, dims):
class MeanOperator(StatisticOperator):
"""Calculate the mean along the specified dimensions."""

name = "mean"

def calc_stats(self, dims):
return self.ds.mean(dim=dims)


class StdOperator(StatisticOperator):
"""Calculate the standard deviation along the specified dimensions."""

name = "std"

def calc_stats(self, dims):
return self.ds.std(dim=dims)


class DiffMeanOperator(StatisticOperator):
"""Calculate the mean of the differences along the specified dimensions."""

name = "diff_mean"

def calc_stats(self, dims):
vars_to_keep = [
v for v in self.ds.data_vars if self.splitting_dim in self.ds[v].dims
Expand All @@ -107,8 +102,6 @@ def calc_stats(self, dims):
class DiffStdOperator(StatisticOperator):
"""Calculate std of the differences along the specified dimensions."""

name = "diff_std"

def calc_stats(self, dims):
vars_to_keep = [
v for v in self.ds.data_vars if self.splitting_dim in self.ds[v].dims
Expand All @@ -117,21 +110,9 @@ def calc_stats(self, dims):
return ds_diff.std(dim=dims)


class DiffTimeMeanOperator(DiffMeanOperator):
"""Calculate the mean of the differences along the time dimension.
This is a duplicate of the DiffMeanOperator to allow for averaging over
other dimensions.
"""

name = "diff_time_mean"


class DiurnalDiffMeanOperator(StatisticOperator):
"""Calculate the mean of the diurnal differences along the specified dimensions."""

name = "diurnal_mean"

def calc_stats(self, dims):
vars_to_keep = [
v for v in self.ds.data_vars if self.splitting_dim in self.ds[v].dims
Expand All @@ -146,8 +127,6 @@ def calc_stats(self, dims):
class DiurnalDiffStdOperator(StatisticOperator):
"""Calculate the std of the diurnal differences along the specified dimensions."""

name = "diurnal_std"

def calc_stats(self, dims):
vars_to_keep = [
v for v in self.ds.data_vars if self.splitting_dim in self.ds[v].dims
Expand Down

0 comments on commit 36cb5d9

Please sign in to comment.